From 5981d31249405151df145d6fadf6dc6110c50a35 Mon Sep 17 00:00:00 2001
From: Zach McKenzie <92116279+zm711@users.noreply.github.com>
Date: Wed, 17 May 2023 21:06:50 -0400
Subject: [PATCH 001/166] Update installation.rst

more rst formatting

---
 doc/installation.rst | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/doc/installation.rst b/doc/installation.rst
index 80452a60e7..acc5117249 100644
--- a/doc/installation.rst
+++ b/doc/installation.rst
@@ -38,7 +38,7 @@ From source
 As :code:`spikeinterface` is undergoing a heavy development phase, it is sometimes convenient
 to install from source to get the latest bug fixes and improvements. We recommend constructing the package within a
-[virtual environment](https://packaging.python.org/en/latest/guides/installing-using-pip-and-virtual-environments/)
+`virtual environment <https://packaging.python.org/en/latest/guides/installing-using-pip-and-virtual-environments/>`_
 to prevent potential conflicts with local dependencies.

 .. code-block:: bash

@@ -49,7 +49,7 @@ to prevent potential conflicts with local dependencies.
     pip install -e .
     cd ..

-Note that this will install the package in [editable mode](https://pip.pypa.io/en/stable/topics/local-project-installs/#editable-installs).
+Note that this will install the package in `editable mode <https://pip.pypa.io/en/stable/topics/local-project-installs/#editable-installs>`_.

 It is also recommended in that case to also install :code:`neo` and :code:`probeinterface` from
 source, as :code:`spikeinterface` strongly relies on these packages to interface with various formats
 and handle probes:

From 08b634d29f37e4d83b8cea9b4220f2ef69893a5d Mon Sep 17 00:00:00 2001
From: Zach McKenzie <92116279+zm711@users.noreply.github.com>
Date: Thu, 18 May 2023 17:00:31 -0400
Subject: [PATCH 002/166] more typo fixes

---
 doc/modules/core.rst | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/doc/modules/core.rst b/doc/modules/core.rst
index 079e1aa0d3..ca658644ea 100644
--- a/doc/modules/core.rst
+++ b/doc/modules/core.rst
@@ -190,7 +190,7 @@ The :py:class:`~spikeinterface.core.WaveformExtractor` class is the core object
 Waveforms are very important for additional analysis, and the basis of several postprocessing
 and quality metrics computations.

-The :py:class:`~spikeinterface.core.WaveformExtractor` allows to:
+The :py:class:`~spikeinterface.core.WaveformExtractor` allows us to:

 * extract and waveforms
 * sub-sample spikes for waveform extraction
 * compute templates (i.e. average extracellular waveforms) with different modes
 * save waveforms on disk for easy retrieval
 * save sparse waveforms or *sparsify* dense waveforms
 * select units and associated waveforms

-The default format (:code:`mode='folder'`) which waveforms are saved to is a folder structure with waveforms as
+The default format (:code:`mode='folder'`) which waveforms are saved to a folder structure with waveforms as
 :code:`.npy` files.
 In addition, waveforms can also be extracted in-memory for fast computations (:code:`mode='memory'`).
 Note that this mode can quickly fill up your RAM... Use it wisely!

 Finally, an existing :py:class:`~spikeinterface.core.WaveformExtractor` can be s

     # (this can also be done within the 'extract_waveforms')
     we.precompute_templates(modes=("std",))

-    # retrieve all template means and standard devs
+    # retrieve all template means and standard deviations
     template_means = we.get_all_templates(mode="average")
     template_stds = we.get_all_templates(mode="std")

 Parallel processing and job_kwargs

 The :py:mod:`~spikeinterface.core` module also contains the basic tools used throughout
 SpikeInterface for parallel processing of recordings.
-In general, parallelization is achieved by splitting the recording in many small time chunks and process +In general, parallelization is achieved by splitting the recording in many small time chunks and processing them in parallel (for more details, see the :py:class:`~spikeinterface.core.ChunkRecordingExecutor` class). Many functions support parallel processing (e.g., :py:func:`~spikeinterface.core.extract_waveforms`, :code:`save`, @@ -494,11 +494,11 @@ These are a set of keyword arguments which are common to all functions that supp If True, a progress bar is printed * mp_context: str or None Context for multiprocessing. It can be None (default), "fork" or "spawn". - Note that "fork" is only available on UNIX systems + Note that "fork" is only available on UNIX systems (not Windows) The default **job_kwargs** are :code:`n_jobs=1, chunk_duration="1s", progress_bar=True`. -Any of these argument, can be overridden by manually passing the argument to a function +Any of these arguments, can be overridden by manually passing the argument to a function (e.g., :code:`extract_waveforms(..., n_jobs=16)`). Alternatively, **job_kwargs** can be set globally (for each SpikeInterface session), with the :py:func:`~spikeinterface.core.set_global_job_kwargs` function: @@ -688,13 +688,13 @@ The :py:mod:`spikeinterface.core.template_tools` submodule includes functionalit Generate toy objects -------------------- -The :py:mod:`~spikeinterface.core` module also offers some functions to generate toy/fake data. +The :py:mod:`~spikeinterface.core` module also offers some functions to generate toy/simulated data. They are useful to make examples, tests, and small demos: .. code-block:: python # recording with 2 segments and 4 channels - recording = generate_recording(generate_recording(num_channels=4, sampling_frequency=30000., + recording = generate_recording(num_channels=4, sampling_frequency=30000., durations=[10.325, 3.5], set_probe=True) # sorting with 2 segments and 5 units From 6a99199e6c1d35b4a65818e7d0b6228ee46e4b11 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 6 Jun 2023 13:51:10 +0200 Subject: [PATCH 003/166] Refactor NumpySorting : use internal numpy spike_vector. Also NumpySorting.from_dict() > NumpySorting.from_unit_dict() because was overwriting base. 
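An illustrative migration sketch for downstream code (the toy spike times and
unit keys below are hypothetical, not part of this patch; the renamed
constructor takes the same arguments as before):

    import numpy as np
    from spikeinterface.core import NumpySorting

    # one dict of {unit_id: spike_frames} per segment
    units_dict = {"u0": np.array([10, 55, 200]), "u1": np.array([3, 150])}
    # before this patch: NumpySorting.from_dict([units_dict], 30000.0)
    sorting = NumpySorting.from_unit_dict([units_dict], sampling_frequency=30000.0)
    spikes = sorting.to_spike_vector()  # the internal long "spike vector"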
--- .../comparison/generate_erroneous_sorting.py | 2 +- src/spikeinterface/comparison/hybrid.py | 2 +- src/spikeinterface/core/basesorting.py | 4 +- src/spikeinterface/core/generate.py | 6 +- src/spikeinterface/core/numpyextractors.py | 159 ++++++++++++++---- .../core/tests/test_basesorting.py | 4 +- .../core/tests/test_frameslicesorting.py | 6 +- .../core/tests/test_numpy_extractors.py | 4 +- .../core/tests/test_waveform_extractor.py | 2 +- src/spikeinterface/core/waveform_extractor.py | 2 +- .../curation/tests/test_curationsorting.py | 4 +- .../tests/test_align_sorting.py | 2 +- .../postprocessing/tests/test_correlograms.py | 4 +- .../tests/test_metrics_functions.py | 2 +- .../sorters/external/mountainsort4.py | 2 +- 15 files changed, 147 insertions(+), 58 deletions(-) diff --git a/examples/modules_gallery/comparison/generate_erroneous_sorting.py b/examples/modules_gallery/comparison/generate_erroneous_sorting.py index b5f53e71ee..d62a15bdc0 100644 --- a/examples/modules_gallery/comparison/generate_erroneous_sorting.py +++ b/examples/modules_gallery/comparison/generate_erroneous_sorting.py @@ -88,7 +88,7 @@ def generate_erroneous_sorting(): for u in [15,16,17]: st = np.sort(np.random.randint(0, high=nframes, size=35)) units_err[u] = st - sorting_err = se.NumpySorting.from_dict(units_err, sampling_frequency) + sorting_err = se.NumpySorting.from_unit_dict(units_err, sampling_frequency) return sorting_true, sorting_err diff --git a/src/spikeinterface/comparison/hybrid.py b/src/spikeinterface/comparison/hybrid.py index 172257c3f1..ae19db3c8f 100644 --- a/src/spikeinterface/comparison/hybrid.py +++ b/src/spikeinterface/comparison/hybrid.py @@ -241,7 +241,7 @@ def generate_injected_sorting( injected_spike_trains[segment_index][unit_id] = injected_spike_train - return NumpySorting.from_dict(injected_spike_trains, sorting.get_sampling_frequency()) + return NumpySorting.from_unit_dict(injected_spike_trains, sorting.get_sampling_frequency()) create_hybrid_units_recording = define_function_from_class( diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index 8504624f3b..d022fb5a07 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -7,6 +7,8 @@ from .waveform_tools import has_exceeding_spikes +minimum_spike_dtype = [("sample_index", "int64"), ("unit_index", "int64"), ("segment_index", "int64")] + class BaseSorting(BaseExtractor): """ Abstract class representing several segment several units and relative spiketrains. 
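A minimal sketch of the layout this dtype defines, assuming the patch is
applied (the three spikes below are made up, with sample indices sorted within
the segment):

    import numpy as np
    from spikeinterface.core.basesorting import minimum_spike_dtype

    spikes = np.zeros(3, dtype=minimum_spike_dtype)
    spikes["sample_index"] = [10, 55, 200]  # spike frames
    spikes["unit_index"] = [0, 1, 0]        # positions into sorting.unit_ids
    spikes["segment_index"] = [0, 0, 0]     # all spikes in the first segment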
@@ -362,7 +364,7 @@ def to_spike_vector(self, extremum_channel_inds=None): spikes_ = self.get_all_spike_trains(outputs="unit_index") n = np.sum([e[0].size for e in spikes_]) - spike_dtype = [("sample_index", "int64"), ("unit_index", "int64"), ("segment_index", "int64")] + spike_dtype = minimum_spike_dtype if extremum_channel_inds is not None: spike_dtype += [("channel_index", "int64")] diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index f77982fd1e..123e2f0bdf 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -107,7 +107,7 @@ def generate_sorting( else: units_dict[unit_id] = np.array([], dtype=int) units_dict_list.append(units_dict) - sorting = NumpySorting.from_dict(units_dict_list, sampling_frequency) + sorting = NumpySorting.from_unit_dict(units_dict_list, sampling_frequency) return sorting @@ -319,7 +319,7 @@ def inject_some_duplicate_units(sorting, num=4, max_shift=5, ratio=None, seed=No d[unit_id] = times spiketrains.append(d) - sorting_with_dup = NumpySorting.from_dict(spiketrains, sampling_frequency=sorting.get_sampling_frequency()) + sorting_with_dup = NumpySorting.from_unit_dict(spiketrains, sampling_frequency=sorting.get_sampling_frequency()) return sorting_with_dup @@ -357,7 +357,7 @@ def inject_some_split_units(sorting, split_ids=[], num_split=2, output_ids=False new_units[unit_id] = original_times spiketrains.append(new_units) - sorting_with_split = NumpySorting.from_dict(spiketrains, sampling_frequency=sorting.get_sampling_frequency()) + sorting_with_split = NumpySorting.from_unit_dict(spiketrains, sampling_frequency=sorting.get_sampling_frequency()) if output_ids: return sorting_with_split, other_ids else: diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py index ac3478f482..e9c71b29b7 100644 --- a/src/spikeinterface/core/numpyextractors.py +++ b/src/spikeinterface/core/numpyextractors.py @@ -9,9 +9,13 @@ BaseSnippets, BaseSnippetsSegment, ) +from .basesorting import minimum_spike_dtype + from typing import List, Union + + class NumpyRecording(BaseRecording): """ In memory recording. @@ -93,30 +97,53 @@ def get_traces(self, start_frame, end_frame, channel_indices): class NumpySorting(BaseSorting): - name = "numpy" + """ + In memory sorting object. + The internal representation is always done with a long "spike vector". - def __init__(self, sampling_frequency, unit_ids=[]): - BaseSorting.__init__(self, sampling_frequency, unit_ids) - self.is_dumpable = False - @staticmethod - def from_extractor(source_sorting: BaseSorting) -> "NumpySorting": + But we have convinient function to instantiate from other sorting object, from time+labels, + from dict of list or from neo. + + Parameters + ---------- + spikes: numpy.array + A numpy vector, the one given by Sorting.to_spike_vector(). + sampling_frequency: float + The sampling frequency in Hz + channel_ids: list + A list of unit_ids. 
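    Examples
    --------
    A minimal sketch with two hypothetical spikes (the unit ids "a" and "b"
    are made up for illustration):

    >>> import numpy as np
    >>> from spikeinterface.core.basesorting import minimum_spike_dtype
    >>> spikes = np.zeros(2, dtype=minimum_spike_dtype)
    >>> spikes["sample_index"] = [10, 55]
    >>> spikes["unit_index"] = [0, 1]
    >>> sorting = NumpySorting(spikes, sampling_frequency=30000.0, unit_ids=["a", "b"])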
+ """ + name = "numpy" + + def __init__(self, spikes, sampling_frequency, unit_ids): """ - Create a numpy sorting from another extractor + """ - unit_ids = source_sorting.get_unit_ids() - nseg = source_sorting.get_num_segments() + BaseSorting.__init__(self, sampling_frequency, unit_ids) + self.is_dumpable = True - sorting = NumpySorting(source_sorting.get_sampling_frequency(), unit_ids) + if spikes.size == 0: + nseg = 0 + else: + nseg = spikes[-1]['segment_index'] + 1 for segment_index in range(nseg): - units_dict = {} - for unit_id in unit_ids: - units_dict[unit_id] = source_sorting.get_unit_spike_train(unit_id, segment_index) - sorting.add_sorting_segment(NumpySortingSegment(units_dict)) + self.add_sorting_segment(NumpySortingSegment(spikes, segment_index, unit_ids)) - sorting.copy_metadata(source_sorting) + self._kwargs = dict(spikes=spikes, sampling_frequency=sampling_frequency, unit_ids=unit_ids) + @staticmethod + def from_extractor(source_sorting: BaseSorting, with_metadata=False) -> "NumpySorting": + """ + Create a numpy sorting from another extractor + """ + + sorting = NumpySorting(source_sorting.to_spike_vector(), + source_sorting.get_sampling_frequency(), + source_sorting.unit_ids) + if with_metadata: + sorting.copy_metadata(source_sorting) return sorting @staticmethod @@ -146,22 +173,32 @@ def from_times_labels(times_list, labels_list, sampling_frequency, unit_ids=None labels_list = [np.asarray(e) for e in labels_list] nseg = len(times_list) + if unit_ids is None: unit_ids = np.unique(np.concatenate([np.unique(labels_list[i]) for i in range(nseg)])) - sorting = NumpySorting(sampling_frequency, unit_ids) + + spikes = [] for i in range(nseg): - units_dict = {} times, labels = times_list[i], labels_list[i] - for unit_id in unit_ids: - mask = labels == unit_id - units_dict[unit_id] = times[mask] - sorting.add_sorting_segment(NumpySortingSegment(units_dict)) + unit_index = np.zeros(labels.size, dtype='int64') + for u, unit_id in enumerate(unit_ids): + unit_index[labels == unit_id] = u + spikes_in_seg = np.zeros(len(times), dtype=minimum_spike_dtype) + spikes_in_seg['sample_index'] = times + spikes_in_seg['unit_index'] = unit_index + spikes_in_seg['segment_index'] = i + order = np.argsort(times) + spikes_in_seg = spikes_in_seg[order] + spikes.append(spikes_in_seg) + spikes = np.concatenate(spikes) + + sorting = NumpySorting(spikes, sampling_frequency, unit_ids) return sorting @staticmethod - def from_dict(units_dict_list, sampling_frequency) -> "NumpySorting": + def from_unit_dict(units_dict_list, sampling_frequency) -> "NumpySorting": """ Construct sorting extractor from a list of dict. 
The list length is the segment count @@ -176,11 +213,35 @@ def from_dict(units_dict_list, sampling_frequency) -> "NumpySorting": unit_ids = list(units_dict_list[0].keys()) - sorting = NumpySorting(sampling_frequency, unit_ids) - for i, units_dict in enumerate(units_dict_list): - sorting.add_sorting_segment(NumpySortingSegment(units_dict)) + nseg = len(units_dict_list) + spikes = [] + for seg_index in range(nseg): + units_dict = units_dict_list[seg_index] + + sample_indices = [] + unit_indices = [] + for u, unit_id in unit_ids: + spike_times = units_dict[unit_id] + sample_indices.append(spike_times) + unit_indices.append(np.full(spike_times.size, u, dtype='int64')) + sample_indices = np.concatenate(sample_indices) + unit_indices = np.concatenate(unit_indices) + + order = np.argsort(sample_indices) + sample_indices = sample_indices[order] + unit_indices = unit_indices[order] + + spikes_in_seg = np.zeros(sample_indices.size, dtype=minimum_spike_dtype) + spikes_in_seg['sample_index'] = sample_indices + spikes_in_seg['unit_index'] = unit_indices + spikes_in_seg['segment_index'] = seg_index + spikes.append(spikes_in_seg) + spikes = np.concatenate(spikes) + + sorting = NumpySorting(spikes, sampling_frequency, unit_ids) + + return sorting - return sorting @staticmethod def from_neo_spiketrain_list(neo_spiketrains, sampling_frequency, unit_ids=None) -> "NumpySorting": @@ -209,18 +270,20 @@ def from_neo_spiketrain_list(neo_spiketrains, sampling_frequency, unit_ids=None) if unit_ids is None: unit_ids = np.arange(len(neo_spiketrains[0]), dtype="int64") - sorting = NumpySorting(sampling_frequency, unit_ids) + units_dict_list = [] for seg_index in range(nseg): units_dict = {} for u, unit_id in enumerate(unit_ids): st = neo_spiketrains[seg_index][u] units_dict[unit_id] = (st.rescale("s").magnitude * sampling_frequency).astype("int64") - sorting.add_sorting_segment(NumpySortingSegment(units_dict)) + units_dict_list.append(units_dict) + + sorting = NumpySorting.from_unit_dict(units_dict_list, sampling_frequency) return sorting @staticmethod - def from_peaks(peaks, sampling_frequency) -> "NumpySorting": + def from_peaks(peaks, sampling_frequency, unit_ids=None) -> "NumpySorting": """ Construct a sorting from peaks returned by 'detect_peaks()' function. 
The unit ids correspond to the recording channel ids and spike trains are the @@ -238,19 +301,39 @@ def from_peaks(peaks, sampling_frequency) -> "NumpySorting": sorting The NumpySorting object """ - return NumpySorting.from_times_labels(peaks["sample_index"], peaks["channel_index"], sampling_frequency) + spikes = np.zeros(peaks.size, dtype=minimum_spike_dtype) + spikes['sample_index'] = peaks['sample_index'] + spikes['unit_index'] = peaks['channel_index'] + spikes['segment_index'] = peaks['segment_index'] + + if unit_ids is None: + unit_ids = np.unique(peaks['channel_index']) + + sorting = NumpySorting(spikes, sampling_frequency, unit_ids) + + return sorting class NumpySortingSegment(BaseSortingSegment): - def __init__(self, units_dict): + def __init__(self, spikes, segment_index, unit_ids): BaseSortingSegment.__init__(self) - for unit_id, times in units_dict.items(): - assert times.dtype.kind == "i", "numpy array of spike times must be integer" - assert np.all(np.diff(times) >= 0), "unsorted times" - self._units_dict = units_dict - + self.spikes = spikes + self.segment_index = segment_index + self.unit_ids = list(unit_ids) + self.spikes_in_seg = None + def get_unit_spike_train(self, unit_id, start_frame, end_frame): - times = self._units_dict[unit_id] + if self.spikes_in_seg is None: + # the slicing of segment is done only once the first time + # this fasten the constructor a lot + s0 = np.searchsorted(self.spikes["segment_index"], self.segment_index, side='left') + s1 = np.searchsorted(self.spikes["segment_index"], self.segment_index + 1, side='left') + self.spikes_in_seg = self.spikes[s0:s1] + + unit_index = self.unit_ids.index(unit_id) + + times = self.spikes_in_seg[self.spikes_in_seg['unit_index'] == unit_index] + if start_frame is not None: times = times[times >= start_frame] if end_frame is not None: @@ -258,6 +341,8 @@ def get_unit_spike_train(self, unit_id, start_frame, end_frame): return times + + class NumpyEvent(BaseEvent): def __init__(self, channel_ids, structured_dtype): BaseEvent.__init__(self, channel_ids, structured_dtype) diff --git a/src/spikeinterface/core/tests/test_basesorting.py b/src/spikeinterface/core/tests/test_basesorting.py index 6e471121b6..d3c4ed14b2 100644 --- a/src/spikeinterface/core/tests/test_basesorting.py +++ b/src/spikeinterface/core/tests/test_basesorting.py @@ -113,7 +113,7 @@ def test_npy_sorting(): "0": np.array([0, 1]), "1": np.array([], dtype="int64"), } - sorting = NumpySorting.from_dict( + sorting = NumpySorting.from_unit_dict( [spike_times_0, spike_times_1], sfreq, ) @@ -144,7 +144,7 @@ def test_npy_sorting(): def test_empty_sorting(): - sorting = NumpySorting.from_dict({}, 30000) + sorting = NumpySorting.from_unit_dict({}, 30000) assert len(sorting.unit_ids) == 0 diff --git a/src/spikeinterface/core/tests/test_frameslicesorting.py b/src/spikeinterface/core/tests/test_frameslicesorting.py index 010d733f6d..e404cfb1be 100644 --- a/src/spikeinterface/core/tests/test_frameslicesorting.py +++ b/src/spikeinterface/core/tests/test_frameslicesorting.py @@ -20,13 +20,13 @@ def test_FrameSliceSorting(): "1": np.arange(min_spike_time, max_spike_time), } # Sorting with attached rec - sorting = NumpySorting.from_dict([spike_times], sf) + sorting = NumpySorting.from_unit_dict([spike_times], sf) rec = NumpyRecording([np.zeros((nsamp, 5))], sampling_frequency=sf) sorting.register_recording(rec) # Sorting without attached rec - sorting_norec = NumpySorting.from_dict([spike_times], sf) + sorting_norec = NumpySorting.from_unit_dict([spike_times], sf) # 
Sorting with attached rec and exceeding spikes - sorting_exceeding = NumpySorting.from_dict([spike_times], sf) + sorting_exceeding = NumpySorting.from_unit_dict([spike_times], sf) rec_exceeding = NumpyRecording([np.zeros((max_spike_time - 1, 5))], sampling_frequency=sf) with warnings.catch_warnings(): warnings.filterwarnings("ignore") diff --git a/src/spikeinterface/core/tests/test_numpy_extractors.py b/src/spikeinterface/core/tests/test_numpy_extractors.py index 23752699a2..edf4c69798 100644 --- a/src/spikeinterface/core/tests/test_numpy_extractors.py +++ b/src/spikeinterface/core/tests/test_numpy_extractors.py @@ -7,6 +7,7 @@ from spikeinterface.core import NumpyRecording, NumpySorting, NumpyEvent from spikeinterface.core import create_sorting_npz from spikeinterface.core import NpzSortingExtractor +from spikeinterface.core.basesorting import minimum_spike_dtype if hasattr(pytest, "global_test_folder"): cache_folder = pytest.global_test_folder / "core" @@ -34,7 +35,8 @@ def test_NumpySorting(): # empty unit_ids = [] - sorting = NumpySorting(sampling_frequency, unit_ids) + spikes = np.zeros(0, dtype=minimum_spike_dtype) + sorting = NumpySorting(spikes, sampling_frequency, unit_ids) # print(sorting) # 2 columns diff --git a/src/spikeinterface/core/tests/test_waveform_extractor.py b/src/spikeinterface/core/tests/test_waveform_extractor.py index 5f5695d7f6..e9d0462359 100644 --- a/src/spikeinterface/core/tests/test_waveform_extractor.py +++ b/src/spikeinterface/core/tests/test_waveform_extractor.py @@ -467,7 +467,7 @@ def test_empty_sorting(): num_channels = 2 recording = generate_recording(num_channels=num_channels, sampling_frequency=sf, durations=[15.32]) - sorting = NumpySorting.from_dict({}, sf) + sorting = NumpySorting.from_unit_dict({}, sf) folder = cache_folder / "empty_sorting" wvf_extractor = extract_waveforms(recording, sorting, folder, allow_unfiltered=True) diff --git a/src/spikeinterface/core/waveform_extractor.py b/src/spikeinterface/core/waveform_extractor.py index 5fbb827237..de37491883 100644 --- a/src/spikeinterface/core/waveform_extractor.py +++ b/src/spikeinterface/core/waveform_extractor.py @@ -1334,7 +1334,7 @@ def run_extract_waveforms(self, seed=None, **job_kwargs): sel = selected_spikes[unit_id][segment_index] selected_spike_times[segment_index][unit_id] = spike_times[sel] - spikes = NumpySorting.from_dict(selected_spike_times, self.sampling_frequency).to_spike_vector() + spikes = NumpySorting.from_unit_dict(selected_spike_times, self.sampling_frequency).to_spike_vector() if self.folder is not None: wf_folder = self.folder / "waveforms" diff --git a/src/spikeinterface/curation/tests/test_curationsorting.py b/src/spikeinterface/curation/tests/test_curationsorting.py index ae625d7e51..e65be2e950 100644 --- a/src/spikeinterface/curation/tests/test_curationsorting.py +++ b/src/spikeinterface/curation/tests/test_curationsorting.py @@ -16,7 +16,7 @@ def test_split_merge(): }, {0: np.arange(15), 1: np.arange(17), 2: np.arange(40, 140), 4: np.arange(40, 140), 5: np.arange(40, 140)}, ] - parent_sort = NumpySorting.from_dict(spikestimes, sampling_frequency=1000) # to have 1 sample=1ms + parent_sort = NumpySorting.from_unit_dict(spikestimes, sampling_frequency=1000) # to have 1 sample=1ms parent_sort.set_property("someprop", [float(k) for k in spikestimes[0].keys()]) # float # %% @@ -54,7 +54,7 @@ def test_curation(): }, {"a": np.arange(12, 15), "b": np.arange(3, 17), "c": np.arange(50)}, ] - parent_sort = NumpySorting.from_dict(spikestimes, sampling_frequency=1000) # to have 
1 sample=1ms + parent_sort = NumpySorting.from_unit_dict(spikestimes, sampling_frequency=1000) # to have 1 sample=1ms parent_sort.set_property("some_names", ["unit_{}".format(k) for k in spikestimes[0].keys()]) # float cs = CurationSorting(parent_sort, properties_policy="remove") # %% diff --git a/src/spikeinterface/postprocessing/tests/test_align_sorting.py b/src/spikeinterface/postprocessing/tests/test_align_sorting.py index f9df45df2a..0adda426a9 100644 --- a/src/spikeinterface/postprocessing/tests/test_align_sorting.py +++ b/src/spikeinterface/postprocessing/tests/test_align_sorting.py @@ -29,7 +29,7 @@ def test_compute_unit_center_of_mass(): # sorting to dict d = {unit_id: sorting.get_unit_spike_train(unit_id) + unit_peak_shifts[unit_id] for unit_id in sorting.unit_ids} - sorting_unaligned = NumpySorting.from_dict(d, sampling_frequency=sorting.get_sampling_frequency()) + sorting_unaligned = NumpySorting.from_unit_dict(d, sampling_frequency=sorting.get_sampling_frequency()) print(sorting_unaligned) sorting_aligned = align_sorting(sorting_unaligned, unit_peak_shifts) diff --git a/src/spikeinterface/postprocessing/tests/test_correlograms.py b/src/spikeinterface/postprocessing/tests/test_correlograms.py index bfbac11722..9c3529345b 100644 --- a/src/spikeinterface/postprocessing/tests/test_correlograms.py +++ b/src/spikeinterface/postprocessing/tests/test_correlograms.py @@ -128,7 +128,7 @@ def test_auto_equal_cross_correlograms(): spike_times = np.sort(np.unique(np.random.randint(0, 100000, num_spike))) num_spike = spike_times.size units_dict = {"1": spike_times, "2": spike_times} - sorting = NumpySorting.from_dict([units_dict], sampling_frequency=10000.0) + sorting = NumpySorting.from_unit_dict([units_dict], sampling_frequency=10000.0) for method in methods: correlograms, bins = compute_correlograms(sorting, window_ms=10.0, bin_ms=0.1, method=method) @@ -178,7 +178,7 @@ def test_detect_injected_correlation(): spike_times2 = np.sort(spike_times2) units_dict = {"1": spike_times1, "2": spike_times2} - sorting = NumpySorting.from_dict([units_dict], sampling_frequency=sampling_frequency) + sorting = NumpySorting.from_unit_dict([units_dict], sampling_frequency=sampling_frequency) for method in methods: correlograms, bins = compute_correlograms(sorting, window_ms=10.0, bin_ms=0.1, method=method) diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py index 8813cd48bb..5e8db56fea 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py +++ b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py @@ -264,7 +264,7 @@ def test_calculate_rp_violations(simulated_data): assert np.allclose(list(rp_contamination_gt.values()), list(rp_contamination.values()), rtol=0.05) np.testing.assert_array_equal(list(counts_gt.values()), list(counts.values())) - sorting = NumpySorting.from_dict({0: np.array([28, 150], dtype=np.int16), 1: np.array([], dtype=np.int16)}, 30000) + sorting = NumpySorting.from_unit_dict({0: np.array([28, 150], dtype=np.int16), 1: np.array([], dtype=np.int16)}, 30000) we.sorting = sorting rp_contamination, counts = compute_refrac_period_violations(we, 1, 0.0) assert np.isnan(rp_contamination[1]) diff --git a/src/spikeinterface/sorters/external/mountainsort4.py b/src/spikeinterface/sorters/external/mountainsort4.py index 0ebf04facc..aaa726c3ee 100644 --- a/src/spikeinterface/sorters/external/mountainsort4.py +++ b/src/spikeinterface/sorters/external/mountainsort4.py @@ 
-138,7 +138,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): # convert sorting to new API and save it unit_ids = old_api_sorting.get_unit_ids() units_dict_list = [{u: old_api_sorting.get_unit_spike_train(u) for u in unit_ids}] - new_api_sorting = NumpySorting.from_dict(units_dict_list, samplerate) + new_api_sorting = NumpySorting.from_unit_dict(units_dict_list, samplerate) NpzSortingExtractor.write_sorting(new_api_sorting, str(sorter_output_folder / "firings.npz")) @classmethod From f539c7944fdb1c2564b78faff00da1a2dffd3280 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 6 Jun 2023 11:58:31 +0000 Subject: [PATCH 004/166] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/basesorting.py | 1 + src/spikeinterface/core/numpyextractors.py | 61 ++++++++----------- .../tests/test_metrics_functions.py | 4 +- 3 files changed, 31 insertions(+), 35 deletions(-) diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index d022fb5a07..225d31f76e 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -9,6 +9,7 @@ minimum_spike_dtype = [("sample_index", "int64"), ("unit_index", "int64"), ("segment_index", "int64")] + class BaseSorting(BaseExtractor): """ Abstract class representing several segment several units and relative spiketrains. diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py index e9c71b29b7..22fb351fb4 100644 --- a/src/spikeinterface/core/numpyextractors.py +++ b/src/spikeinterface/core/numpyextractors.py @@ -14,8 +14,6 @@ from typing import List, Union - - class NumpyRecording(BaseRecording): """ In memory recording. @@ -102,8 +100,8 @@ class NumpySorting(BaseSorting): The internal representation is always done with a long "spike vector". - But we have convinient function to instantiate from other sorting object, from time+labels, - from dict of list or from neo. + But we have convinient function to instantiate from other sorting object, from time+labels, + from dict of list or from neo. Parameters ---------- @@ -114,19 +112,18 @@ class NumpySorting(BaseSorting): channel_ids: list A list of unit_ids. 
""" + name = "numpy" def __init__(self, spikes, sampling_frequency, unit_ids): - """ - - """ + """ """ BaseSorting.__init__(self, sampling_frequency, unit_ids) self.is_dumpable = True if spikes.size == 0: nseg = 0 else: - nseg = spikes[-1]['segment_index'] + 1 + nseg = spikes[-1]["segment_index"] + 1 for segment_index in range(nseg): self.add_sorting_segment(NumpySortingSegment(spikes, segment_index, unit_ids)) @@ -139,9 +136,9 @@ def from_extractor(source_sorting: BaseSorting, with_metadata=False) -> "NumpySo Create a numpy sorting from another extractor """ - sorting = NumpySorting(source_sorting.to_spike_vector(), - source_sorting.get_sampling_frequency(), - source_sorting.unit_ids) + sorting = NumpySorting( + source_sorting.to_spike_vector(), source_sorting.get_sampling_frequency(), source_sorting.unit_ids + ) if with_metadata: sorting.copy_metadata(source_sorting) return sorting @@ -177,17 +174,16 @@ def from_times_labels(times_list, labels_list, sampling_frequency, unit_ids=None if unit_ids is None: unit_ids = np.unique(np.concatenate([np.unique(labels_list[i]) for i in range(nseg)])) - spikes = [] for i in range(nseg): times, labels = times_list[i], labels_list[i] - unit_index = np.zeros(labels.size, dtype='int64') + unit_index = np.zeros(labels.size, dtype="int64") for u, unit_id in enumerate(unit_ids): unit_index[labels == unit_id] = u spikes_in_seg = np.zeros(len(times), dtype=minimum_spike_dtype) - spikes_in_seg['sample_index'] = times - spikes_in_seg['unit_index'] = unit_index - spikes_in_seg['segment_index'] = i + spikes_in_seg["sample_index"] = times + spikes_in_seg["unit_index"] = unit_index + spikes_in_seg["segment_index"] = i order = np.argsort(times) spikes_in_seg = spikes_in_seg[order] spikes.append(spikes_in_seg) @@ -223,25 +219,24 @@ def from_unit_dict(units_dict_list, sampling_frequency) -> "NumpySorting": for u, unit_id in unit_ids: spike_times = units_dict[unit_id] sample_indices.append(spike_times) - unit_indices.append(np.full(spike_times.size, u, dtype='int64')) + unit_indices.append(np.full(spike_times.size, u, dtype="int64")) sample_indices = np.concatenate(sample_indices) unit_indices = np.concatenate(unit_indices) - + order = np.argsort(sample_indices) sample_indices = sample_indices[order] unit_indices = unit_indices[order] spikes_in_seg = np.zeros(sample_indices.size, dtype=minimum_spike_dtype) - spikes_in_seg['sample_index'] = sample_indices - spikes_in_seg['unit_index'] = unit_indices - spikes_in_seg['segment_index'] = seg_index + spikes_in_seg["sample_index"] = sample_indices + spikes_in_seg["unit_index"] = unit_indices + spikes_in_seg["segment_index"] = seg_index spikes.append(spikes_in_seg) spikes = np.concatenate(spikes) sorting = NumpySorting(spikes, sampling_frequency, unit_ids) - return sorting - + return sorting @staticmethod def from_neo_spiketrain_list(neo_spiketrains, sampling_frequency, unit_ids=None) -> "NumpySorting": @@ -302,12 +297,12 @@ def from_peaks(peaks, sampling_frequency, unit_ids=None) -> "NumpySorting": The NumpySorting object """ spikes = np.zeros(peaks.size, dtype=minimum_spike_dtype) - spikes['sample_index'] = peaks['sample_index'] - spikes['unit_index'] = peaks['channel_index'] - spikes['segment_index'] = peaks['segment_index'] + spikes["sample_index"] = peaks["sample_index"] + spikes["unit_index"] = peaks["channel_index"] + spikes["segment_index"] = peaks["segment_index"] if unit_ids is None: - unit_ids = np.unique(peaks['channel_index']) + unit_ids = np.unique(peaks["channel_index"]) sorting = NumpySorting(spikes, 
sampling_frequency, unit_ids) @@ -321,18 +316,18 @@ def __init__(self, spikes, segment_index, unit_ids): self.segment_index = segment_index self.unit_ids = list(unit_ids) self.spikes_in_seg = None - + def get_unit_spike_train(self, unit_id, start_frame, end_frame): if self.spikes_in_seg is None: # the slicing of segment is done only once the first time # this fasten the constructor a lot - s0 = np.searchsorted(self.spikes["segment_index"], self.segment_index, side='left') - s1 = np.searchsorted(self.spikes["segment_index"], self.segment_index + 1, side='left') + s0 = np.searchsorted(self.spikes["segment_index"], self.segment_index, side="left") + s1 = np.searchsorted(self.spikes["segment_index"], self.segment_index + 1, side="left") self.spikes_in_seg = self.spikes[s0:s1] - + unit_index = self.unit_ids.index(unit_id) - times = self.spikes_in_seg[self.spikes_in_seg['unit_index'] == unit_index] + times = self.spikes_in_seg[self.spikes_in_seg["unit_index"] == unit_index] if start_frame is not None: times = times[times >= start_frame] @@ -341,8 +336,6 @@ def get_unit_spike_train(self, unit_id, start_frame, end_frame): return times - - class NumpyEvent(BaseEvent): def __init__(self, channel_ids, structured_dtype): BaseEvent.__init__(self, channel_ids, structured_dtype) diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py index 5e8db56fea..cabd7b847b 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py +++ b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py @@ -264,7 +264,9 @@ def test_calculate_rp_violations(simulated_data): assert np.allclose(list(rp_contamination_gt.values()), list(rp_contamination.values()), rtol=0.05) np.testing.assert_array_equal(list(counts_gt.values()), list(counts.values())) - sorting = NumpySorting.from_unit_dict({0: np.array([28, 150], dtype=np.int16), 1: np.array([], dtype=np.int16)}, 30000) + sorting = NumpySorting.from_unit_dict( + {0: np.array([28, 150], dtype=np.int16), 1: np.array([], dtype=np.int16)}, 30000 + ) we.sorting = sorting rp_contamination, counts = compute_refrac_period_violations(we, 1, 0.0) assert np.isnan(rp_contamination[1]) From f81dd7888e05d8468fc20dcf5db30e931fa09012 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 6 Jun 2023 14:11:49 +0200 Subject: [PATCH 005/166] oups --- src/spikeinterface/core/numpyextractors.py | 24 ++++++++++++---------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py index e9c71b29b7..1790818489 100644 --- a/src/spikeinterface/core/numpyextractors.py +++ b/src/spikeinterface/core/numpyextractors.py @@ -124,7 +124,7 @@ def __init__(self, spikes, sampling_frequency, unit_ids): self.is_dumpable = True if spikes.size == 0: - nseg = 0 + nseg = 1 else: nseg = spikes[-1]['segment_index'] + 1 @@ -220,18 +220,19 @@ def from_unit_dict(units_dict_list, sampling_frequency) -> "NumpySorting": sample_indices = [] unit_indices = [] - for u, unit_id in unit_ids: + for u, unit_id in enumerate(unit_ids): spike_times = units_dict[unit_id] sample_indices.append(spike_times) unit_indices.append(np.full(spike_times.size, u, dtype='int64')) - sample_indices = np.concatenate(sample_indices) - unit_indices = np.concatenate(unit_indices) - - order = np.argsort(sample_indices) - sample_indices = sample_indices[order] - unit_indices = unit_indices[order] - - spikes_in_seg = np.zeros(sample_indices.size, 
dtype=minimum_spike_dtype) + if len(sample_indices) > 0: + sample_indices = np.concatenate(sample_indices) + unit_indices = np.concatenate(unit_indices) + + order = np.argsort(sample_indices) + sample_indices = sample_indices[order] + unit_indices = unit_indices[order] + + spikes_in_seg = np.zeros(len(sample_indices), dtype=minimum_spike_dtype) spikes_in_seg['sample_index'] = sample_indices spikes_in_seg['unit_index'] = unit_indices spikes_in_seg['segment_index'] = seg_index @@ -332,7 +333,8 @@ def get_unit_spike_train(self, unit_id, start_frame, end_frame): unit_index = self.unit_ids.index(unit_id) - times = self.spikes_in_seg[self.spikes_in_seg['unit_index'] == unit_index] + times = self.spikes_in_seg[self.spikes_in_seg['unit_index'] == unit_index]['sample_index'] + if start_frame is not None: times = times[times >= start_frame] From 1ef08afa70a912662a0b022d7ee5f6de7a325df1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 6 Jun 2023 12:20:38 +0000 Subject: [PATCH 006/166] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/numpyextractors.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py index 74a9470eaf..e4094af9d6 100644 --- a/src/spikeinterface/core/numpyextractors.py +++ b/src/spikeinterface/core/numpyextractors.py @@ -220,19 +220,19 @@ def from_unit_dict(units_dict_list, sampling_frequency) -> "NumpySorting": spike_times = units_dict[unit_id] sample_indices.append(spike_times) - unit_indices.append(np.full(spike_times.size, u, dtype='int64')) + unit_indices.append(np.full(spike_times.size, u, dtype="int64")) if len(sample_indices) > 0: sample_indices = np.concatenate(sample_indices) unit_indices = np.concatenate(unit_indices) - + order = np.argsort(sample_indices) sample_indices = sample_indices[order] unit_indices = unit_indices[order] spikes_in_seg = np.zeros(len(sample_indices), dtype=minimum_spike_dtype) - spikes_in_seg['sample_index'] = sample_indices - spikes_in_seg['unit_index'] = unit_indices - spikes_in_seg['segment_index'] = seg_index + spikes_in_seg["sample_index"] = sample_indices + spikes_in_seg["unit_index"] = unit_indices + spikes_in_seg["segment_index"] = seg_index spikes.append(spikes_in_seg) spikes = np.concatenate(spikes) @@ -329,9 +329,8 @@ def get_unit_spike_train(self, unit_id, start_frame, end_frame): unit_index = self.unit_ids.index(unit_id) + times = self.spikes_in_seg[self.spikes_in_seg["unit_index"] == unit_index]["sample_index"] - times = self.spikes_in_seg[self.spikes_in_seg['unit_index'] == unit_index]['sample_index'] - if start_frame is not None: times = times[times >= start_frame] if end_frame is not None: From 06c4e3e93b0e72681d4ffdd6892360d0a1c258b8 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 6 Jun 2023 14:28:09 +0200 Subject: [PATCH 007/166] Sorting.from_extractor() > Sorting.from_sorting() --- src/spikeinterface/core/basesorting.py | 2 +- src/spikeinterface/core/numpyextractors.py | 4 ++-- src/spikeinterface/core/tests/test_numpy_extractors.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index 225d31f76e..2eb48387f5 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -217,7 +217,7 @@ def _save(self, format="npz", 
**save_kwargs): elif format == "memory": from .numpyextractors import NumpySorting - cached = NumpySorting.from_extractor(self) + cached = NumpySorting.from_sorting(self) else: raise ValueError(f"format {format} not supported") return cached diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py index 74a9470eaf..5558a7c757 100644 --- a/src/spikeinterface/core/numpyextractors.py +++ b/src/spikeinterface/core/numpyextractors.py @@ -131,9 +131,9 @@ def __init__(self, spikes, sampling_frequency, unit_ids): self._kwargs = dict(spikes=spikes, sampling_frequency=sampling_frequency, unit_ids=unit_ids) @staticmethod - def from_extractor(source_sorting: BaseSorting, with_metadata=False) -> "NumpySorting": + def from_sorting(source_sorting: BaseSorting, with_metadata=False) -> "NumpySorting": """ - Create a numpy sorting from another extractor + Create a numpy sorting from another sorting extractor """ sorting = NumpySorting( diff --git a/src/spikeinterface/core/tests/test_numpy_extractors.py b/src/spikeinterface/core/tests/test_numpy_extractors.py index edf4c69798..9970b8b8b0 100644 --- a/src/spikeinterface/core/tests/test_numpy_extractors.py +++ b/src/spikeinterface/core/tests/test_numpy_extractors.py @@ -59,7 +59,7 @@ def test_NumpySorting(): create_sorting_npz(num_seg, file_path) other_sorting = NpzSortingExtractor(file_path) - sorting = NumpySorting.from_extractor(other_sorting) + sorting = NumpySorting.from_sorting(other_sorting) # print(sorting) From 84241ea66451686b66e0cd21e74fcae9b6d08a85 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 6 Jun 2023 16:07:28 +0200 Subject: [PATCH 008/166] Deprecate Sorting.get_all_spike_trains() in favor of Sorting.to_spike_vector() --- src/spikeinterface/core/basesorting.py | 58 +++++++++++++------ src/spikeinterface/core/snippets_tools.py | 4 +- .../core/tests/test_basesnippets.py | 2 +- .../core/tests/test_basesorting.py | 12 ++-- src/spikeinterface/exporters/to_phy.py | 5 +- .../postprocessing/correlograms.py | 11 ++-- src/spikeinterface/postprocessing/isi.py | 9 ++- .../postprocessing/principal_component.py | 5 +- .../postprocessing/spike_amplitudes.py | 25 +++++--- .../tests/test_principal_component.py | 9 ++- .../qualitymetrics/misc_metrics.py | 8 +-- .../benchmark/benchmark_clustering.py | 17 +++--- .../benchmark/benchmark_peak_localization.py | 10 ++-- .../benchmark/benchmark_peak_selection.py | 42 +++++++------- .../tests/test_peak_detection.py | 2 +- 15 files changed, 129 insertions(+), 90 deletions(-) diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index 2eb48387f5..3712306c28 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -313,6 +313,13 @@ def get_all_spike_trains(self, outputs="unit_id"): """ Return all spike trains concatenated """ + + warnings.warn( + "Sorting.get_all_spike_trains() will be deprecated. Sorting.to_spike_vector() instead", + DeprecationWarning, + stacklevel=2, + ) + assert outputs in ("unit_id", "unit_index") spikes = [] for segment_index in range(self.get_num_segments()): @@ -339,7 +346,7 @@ def get_all_spike_trains(self, outputs="unit_id"): spikes.append((spike_times, spike_labels)) return spikes - def to_spike_vector(self, extremum_channel_inds=None): + def to_spike_vector(self, concatenated=True, extremum_channel_inds=None): """ Construct a unique structured numpy vector concatenating all spikes with several fields: sample_index, unit_index, segment_index. 
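A usage sketch of the new flag on an existing sorting object (variable names
are illustrative):

    spikes = sorting.to_spike_vector()  # one structured array over all segments
    spikes_per_segment = sorting.to_spike_vector(concatenated=False)  # list, one array per segment
    seg0_sample_indices = spikes_per_segment[0]["sample_index"]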
@@ -348,6 +355,9 @@ def to_spike_vector(self, extremum_channel_inds=None): Parameters ---------- + concatenated: bool + By default the output is one numpy vector. + With concatenated=False then it is a list of vector by segment. extremum_channel_inds: None or dict If a dictionnary of unit_id to channel_ind is given then an extra field 'channel_index'. This can be convinient for computing spikes postion after sorter. @@ -362,28 +372,40 @@ def to_spike_vector(self, extremum_channel_inds=None): is given """ - spikes_ = self.get_all_spike_trains(outputs="unit_index") - - n = np.sum([e[0].size for e in spikes_]) spike_dtype = minimum_spike_dtype - if extremum_channel_inds is not None: spike_dtype += [("channel_index", "int64")] - spikes = np.zeros(n, dtype=spike_dtype) - - pos = 0 - for segment_index, (spike_times, spike_labels) in enumerate(spikes_): - n = spike_times.size - spikes[pos : pos + n]["sample_index"] = spike_times - spikes[pos : pos + n]["unit_index"] = spike_labels - spikes[pos : pos + n]["segment_index"] = segment_index - pos += n + spikes = [] + for segment_index in range(self.get_num_segments()): + sample_indices = [] + unit_indices = [] + for u, unit_id in enumerate(self.unit_ids): + spike_times = st = self.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index) + sample_indices.append(spike_times) + unit_indices.append(np.full(spike_times.size, u, dtype="int64")) + + if len(sample_indices) > 0: + sample_indices = np.concatenate(sample_indices, dtype='int64') + unit_indices = np.concatenate(unit_indices, dtype='int64') + order = np.argsort(sample_indices) + sample_indices = sample_indices[order] + unit_indices = unit_indices[order] + + spikes_in_seg = np.zeros(len(sample_indices), dtype=minimum_spike_dtype) + spikes_in_seg["sample_index"] = sample_indices + spikes_in_seg["unit_index"] = unit_indices + spikes_in_seg["segment_index"] = segment_index + spikes.append(spikes_in_seg) + + if extremum_channel_inds is not None: + ext_channel_inds = np.array([extremum_channel_inds[unit_id] for unit_id in self.unit_ids]) + # vector way + spikes_in_seg["channel_index"] = ext_channel_inds[spikes_in_seg["unit_index"]] + + if concatenated: + spikes = np.concatenate(spikes) - if extremum_channel_inds is not None: - ext_channel_inds = np.array([extremum_channel_inds[unit_id] for unit_id in self.unit_ids]) - # vector way - spikes["channel_index"] = ext_channel_inds[spikes["unit_index"]] return spikes diff --git a/src/spikeinterface/core/snippets_tools.py b/src/spikeinterface/core/snippets_tools.py index a88056b8b1..454d3622f3 100644 --- a/src/spikeinterface/core/snippets_tools.py +++ b/src/spikeinterface/core/snippets_tools.py @@ -26,7 +26,7 @@ def snippets_from_sorting(recording, sorting, nbefore=20, nafter=44, wf_folder=N Snippets extractor created """ job_kwargs = fix_job_kwargs(job_kwargs) - strains = sorting.get_all_spike_trains() + spikes = sorting.to_spike_vector(concatenated=False) peaks2 = sorting.to_spike_vector() peaks2["unit_index"] = 0 @@ -58,7 +58,7 @@ def snippets_from_sorting(recording, sorting, nbefore=20, nafter=44, wf_folder=N nse = NumpySnippets( snippets_list=wfs, - spikesframes_list=[np.sort(s[0]) for s in strains], + spikesframes_list=[s['sample_index'] for s in spikes], sampling_frequency=recording.get_sampling_frequency(), nbefore=nbefore, channel_ids=recording.get_channel_ids(), diff --git a/src/spikeinterface/core/tests/test_basesnippets.py b/src/spikeinterface/core/tests/test_basesnippets.py index d286a0dd37..3fd5091486 100644 --- 
a/src/spikeinterface/core/tests/test_basesnippets.py +++ b/src/spikeinterface/core/tests/test_basesnippets.py @@ -87,7 +87,7 @@ def test_BaseSnippets(): times0 = snippets.get_frames(segment_index=0) - seg0_times = sorting.get_all_spike_trains()[0][0] + seg0_times = sorting.to_spike_vector(concatenated=False)[0]['sample_index'] assert np.array_equal(seg0_times, times0) diff --git a/src/spikeinterface/core/tests/test_basesorting.py b/src/spikeinterface/core/tests/test_basesorting.py index d3c4ed14b2..f5307d3a28 100644 --- a/src/spikeinterface/core/tests/test_basesorting.py +++ b/src/spikeinterface/core/tests/test_basesorting.py @@ -81,7 +81,8 @@ def test_BaseSorting(): sorting4 = sorting.save(format="memory") check_sortings_equal(sorting, sorting4, check_annotations=True, check_properties=True) - spikes = sorting.get_all_spike_trains() + with pytest.warns(DeprecationWarning): + spikes = sorting.get_all_spike_trains() # print(spikes) spikes = sorting.to_spike_vector() @@ -148,10 +149,11 @@ def test_empty_sorting(): assert len(sorting.unit_ids) == 0 - spikes = sorting.get_all_spike_trains() - assert len(spikes) == 1 - assert len(spikes[0][0]) == 0 - assert len(spikes[0][1]) == 0 + with pytest.warns(DeprecationWarning): + spikes = sorting.get_all_spike_trains() + assert len(spikes) == 1 + assert len(spikes[0][0]) == 0 + assert len(spikes[0][1]) == 0 spikes = sorting.to_spike_vector() assert spikes.shape == (0,) diff --git a/src/spikeinterface/exporters/to_phy.py b/src/spikeinterface/exporters/to_phy.py index 2b50028aa9..6a3ecc205d 100644 --- a/src/spikeinterface/exporters/to_phy.py +++ b/src/spikeinterface/exporters/to_phy.py @@ -147,8 +147,9 @@ def export_to_phy( # export spike_times/spike_templates/spike_clusters # here spike_labels is a remapping to unit_index - all_spikes = sorting.get_all_spike_trains(outputs="unit_index") - spike_times, spike_labels = all_spikes[0] + all_spikes_seg0 = sorting.to_spike_vector(concatenated=False)[0] + spike_times = all_spikes_seg0["sample_index"] + spike_labels = all_spikes_seg0["unit_index"] np.save(str(output_folder / "spike_times.npy"), spike_times[:, np.newaxis]) np.save(str(output_folder / "spike_templates.npy"), spike_labels[:, np.newaxis]) np.save(str(output_folder / "spike_clusters.npy"), spike_labels[:, np.newaxis]) diff --git a/src/spikeinterface/postprocessing/correlograms.py b/src/spikeinterface/postprocessing/correlograms.py index 39118e6304..d6e074fc2c 100644 --- a/src/spikeinterface/postprocessing/correlograms.py +++ b/src/spikeinterface/postprocessing/correlograms.py @@ -216,7 +216,7 @@ def compute_correlograms_numpy(sorting, window_size, bin_size): """ num_seg = sorting.get_num_segments() num_units = len(sorting.unit_ids) - spikes = sorting.get_all_spike_trains(outputs="unit_index") + spikes = sorting.to_spike_vector(concatenated=False) num_half_bins = int(window_size // bin_size) num_bins = int(2 * num_half_bins) @@ -224,7 +224,8 @@ def compute_correlograms_numpy(sorting, window_size, bin_size): correlograms = np.zeros((num_units, num_units, num_bins), dtype="int64") for seg_index in range(num_seg): - spike_times, spike_labels = spikes[seg_index] + spike_times = spikes[seg_index]['sample_index'] + spike_labels = spikes[seg_index]['unit_index'] c0 = correlogram_for_one_segment(spike_times, spike_labels, window_size, bin_size) @@ -305,11 +306,13 @@ def compute_correlograms_numba(sorting, window_size, bin_size): num_bins = 2 * int(window_size / bin_size) num_units = len(sorting.unit_ids) - spikes = 
sorting.get_all_spike_trains(outputs="unit_index")
+    spikes = sorting.to_spike_vector(concatenated=False)

     correlograms = np.zeros((num_units, num_units, num_bins), dtype=np.int64)

     for seg_index in range(sorting.get_num_segments()):
-        spike_times, spike_labels = spikes[seg_index]
+        spike_times = spikes[seg_index]['sample_index']
+        spike_labels = spikes[seg_index]['unit_index']
+
         _compute_correlograms_numba(
             correlograms, spike_times.astype(np.int64), spike_labels.astype(np.int32), window_size, bin_size
         )
diff --git a/src/spikeinterface/postprocessing/isi.py b/src/spikeinterface/postprocessing/isi.py
index 678ce8c2fd..eac10fa763 100644
--- a/src/spikeinterface/postprocessing/isi.py
+++ b/src/spikeinterface/postprocessing/isi.py
@@ -233,15 +233,18 @@ def compute_isi_histograms_numba(sorting, window_ms: float = 50.0, bin_ms: float
     assert num_bins >= 1

     bins = np.arange(0, window_size + bin_size, bin_size) * 1e3 / fs
-    spikes = sorting.get_all_spike_trains(outputs="unit_index")
+    spikes = sorting.to_spike_vector(concatenated=False)

     ISIs = np.zeros((num_units, num_bins), dtype=np.int64)

     for seg_index in range(sorting.get_num_segments()):
+        spike_times = spikes[seg_index]['sample_index'].astype(np.int64)
+        spike_labels = spikes[seg_index]['unit_index'].astype(np.int32)
+
         _compute_isi_histograms_numba(
             ISIs,
-            spikes[seg_index][0].astype(np.int64),
-            spikes[seg_index][1].astype(np.int32),
+            spike_times,
+            spike_labels,
             window_size,
             bin_size,
             fs,
diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py
index f6bdf5c5e1..84cbeb9696 100644
--- a/src/spikeinterface/postprocessing/principal_component.py
+++ b/src/spikeinterface/postprocessing/principal_component.py
@@ -307,8 +307,9 @@ def run_for_all_spikes(self, file_path=None, **job_kwargs):
             file_path = self.extension_folder / "all_pcs.npy"
         file_path = Path(file_path)

-        all_spikes = sorting.get_all_spike_trains(outputs="unit_index")
-        spike_times, spike_labels = all_spikes[0]
+        spikes = sorting.to_spike_vector(concatenated=False)
+        spike_times = spikes['sample_index']
+        spike_labels = spikes['unit_index']

         sparsity = self.get_sparsity()
         if sparsity is None:
diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py
index c8bab06fbf..77adb0536f 100644
--- a/src/spikeinterface/postprocessing/spike_amplitudes.py
+++ b/src/spikeinterface/postprocessing/spike_amplitudes.py
@@ -26,12 +26,15 @@ def _set_params(self, peak_sign="neg", return_scaled=True):

     def _select_extension_data(self, unit_ids):
         # load filter and save amplitude files
+        sorting = self.waveform_extractor.sorting
+        spikes = sorting.to_spike_vector(concatenated=False)
+        keep_unit_indices, = np.nonzero(np.in1d(sorting.unit_ids, unit_ids))
+
         new_extension_data = dict()
-        for seg_index in range(self.waveform_extractor.recording.get_num_segments()):
+        for seg_index in range(sorting.get_num_segments()):
             amp_data_name = f"amplitude_segment_{seg_index}"
             amps = self._extension_data[amp_data_name]
-            _, all_labels = self.waveform_extractor.sorting.get_all_spike_trains()[seg_index]
-            filtered_idxs = np.in1d(all_labels, np.array(unit_ids)).nonzero()
+            filtered_idxs = np.in1d(spikes[seg_index]['unit_index'], keep_unit_indices)
             new_extension_data[amp_data_name] = amps[filtered_idxs]

         return new_extension_data
@@ -45,7 +48,7 @@ def _run(self, **job_kwargs):
         recording = we.recording
         sorting = we.sorting

-        all_spikes = sorting.get_all_spike_trains(outputs="unit_index")
+
all_spikes = sorting.to_spike_vector() self._all_spikes = all_spikes peak_sign = self._params["peak_sign"] @@ -107,7 +110,7 @@ def get_data(self, outputs="concatenated"): """ we = self.waveform_extractor sorting = we.sorting - all_spikes = sorting.get_all_spike_trains(outputs="unit_index") + if outputs == "concatenated": amplitudes = [] @@ -115,11 +118,13 @@ def get_data(self, outputs="concatenated"): amplitudes.append(self._extension_data[f"amplitude_segment_{segment_index}"]) return amplitudes elif outputs == "by_unit": + all_spikes = sorting.to_spike_vector(concatenated=False) + amplitudes_by_unit = [] for segment_index in range(we.get_num_segments()): amplitudes_by_unit.append({}) for unit_index, unit_id in enumerate(sorting.unit_ids): - _, spike_labels = all_spikes[segment_index] + spike_labels = all_spikes[segment_index]['unit_index'] mask = spike_labels == unit_index amps = self._extension_data[f"amplitude_segment_{segment_index}"][mask] amplitudes_by_unit[segment_index][unit_id] = amps @@ -193,8 +198,8 @@ def _init_worker_spike_amplitudes(recording, sorting, extremum_channels_index, p worker_ctx["min_shift"] = np.min(peak_shifts) worker_ctx["max_shifts"] = np.max(peak_shifts) - all_spikes = sorting.get_all_spike_trains(outputs="unit_index") - worker_ctx["all_spikes"] = all_spikes + + worker_ctx["all_spikes"] = sorting.to_spike_vector(concatenated=False) worker_ctx["extremum_channels_index"] = extremum_channels_index return worker_ctx @@ -209,7 +214,9 @@ def _spike_amplitudes_chunk(segment_index, start_frame, end_frame, worker_ctx): seg_size = recording.get_num_samples(segment_index=segment_index) - spike_times, spike_labels = all_spikes[segment_index] + spike_times = all_spikes[segment_index]['sample_index'] + spike_labels = all_spikes[segment_index]['unit_index'] + d = np.diff(spike_times) assert np.all(d >= 0) diff --git a/src/spikeinterface/postprocessing/tests/test_principal_component.py b/src/spikeinterface/postprocessing/tests/test_principal_component.py index 870e710877..90d6a2f3f8 100644 --- a/src/spikeinterface/postprocessing/tests/test_principal_component.py +++ b/src/spikeinterface/postprocessing/tests/test_principal_component.py @@ -64,11 +64,10 @@ def test_compute_for_all_spikes(self): pc_file_sparse = pc.extension_folder / "all_pc_sparse.npy" pc_sparse.run_for_all_spikes(pc_file_sparse, chunk_size=10000, n_jobs=1) all_pc_sparse = np.load(pc_file_sparse) - all_spikes = we_copy.sorting.get_all_spike_trains(outputs="unit_id") - _, spike_labels = all_spikes[0] - for unit_id, sparse_channel_ids in sparsity.unit_id_to_channel_ids.items(): - # check dimensions - pc_unit = all_pc_sparse[spike_labels == unit_id] + all_spikes_seg0 = we_copy.sorting.to_spike_vector(concatenated=False)[0] + for unit_index, unit_id in enumerate(we.unit_ids): + sparse_channel_ids = sparsity.unit_id_to_channel_ids[unit_id] + pc_unit = all_pc_sparse[all_spikes_seg0['unit_index'] == unit_index] assert np.allclose(pc_unit[:, :, len(sparse_channel_ids) :], 0) def test_sparse(self): diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/qualitymetrics/misc_metrics.py index 42f6b8b0da..42eb3e6677 100644 --- a/src/spikeinterface/qualitymetrics/misc_metrics.py +++ b/src/spikeinterface/qualitymetrics/misc_metrics.py @@ -342,7 +342,7 @@ def compute_refrac_period_violations( fs = sorting.get_sampling_frequency() num_units = len(sorting.unit_ids) num_segments = sorting.get_num_segments() - spikes = sorting.get_all_spike_trains(outputs="unit_index") + spikes = 
sorting.to_spike_vector(concatenated=False) num_spikes = compute_num_spikes(waveform_extractor) t_c = int(round(censored_period_ms * fs * 1e-3)) @@ -350,9 +350,9 @@ def compute_refrac_period_violations( nb_rp_violations = np.zeros((num_units), dtype=np.int64) for seg_index in range(num_segments): - _compute_rp_violations_numba( - nb_rp_violations, spikes[seg_index][0].astype(np.int64), spikes[seg_index][1].astype(np.int32), t_c, t_r - ) + spike_times = spikes[seg_index]['sample_index'].astype(np.int64) + spike_labels = spikes[seg_index]['unit_index'].astype(np.int32) + _compute_rp_violations_numba(nb_rp_violations, spike_times, spike_labels, t_c, t_r) T = waveform_extractor.get_total_samples() diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py index 7eea75ce81..da416ba8f4 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py @@ -161,15 +161,15 @@ def run(self, peaks=None, positions=None, method=None, method_kwargs={}, delta=0 if self.verbose: print("Performing the comparison with (sliced) ground truth") - times1 = self.gt_sorting.get_all_spike_trains()[0] - times2 = self.clustering.get_all_spike_trains()[0] - matches = make_matching_events(times1[0], times2[0], int(delta * self.sampling_rate / 1000)) + spikes1 = self.gt_sorting.to_spike_vector(concatenated=False)[0] + spikes2 = self.clustering.to_spike_vector(concatenated=False)[0] + + matches = make_matching_events(spikes1['sample_index'], spikes2['sample_index'], + int(delta * self.sampling_rate / 1000)) self.matches = matches idx = matches["index1"] - self.sliced_gt_sorting = NumpySorting.from_times_labels( - times1[0][idx], times1[1][idx], self.sampling_rate, unit_ids=self.gt_sorting.unit_ids - ) + self.sliced_gt_sorting = NumpySorting(spikes1[idx], self.sampling_rate, self.gt_sorting.unit_ids) self.comp = GroundTruthComparison(self.sliced_gt_sorting, self.clustering, exhaustive_gt=self.exhaustive_gt) @@ -251,10 +251,11 @@ def _scatter_clusters( # scatter and collect gaussian info means = {} covs = {} - labels_ids = sorting.get_all_spike_trains()[0][1] + labels = sorting.to_spike_vector(concatenated=False)[0]['unit_index'] for unit_ind, unit_id in enumerate(sorting.unit_ids): - where = np.flatnonzero(labels_ids == unit_id) + where = np.flatnonzero(labels == unit_ind) + xk = xs[where] yk = ys[where] diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py index b5ad24a5b3..84b8db7892 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py @@ -447,12 +447,12 @@ def plot_figure_1(benchmark, mode="average", cell_ind="auto"): import spikeinterface.full as si - unit_id = benchmark.waveforms.sorting.unit_ids[cell_ind] + sorting = benchmark.waveforms.sorting + unit_id = sorting.unit_ids[cell_ind] - mask = benchmark.waveforms.sorting.get_all_spike_trains()[0][1] == unit_id - times = ( - benchmark.waveforms.sorting.get_all_spike_trains()[0][0][mask] / benchmark.recording.get_sampling_frequency() - ) + spikes_seg0 = sorting.to_spike_vector(concatenated=False)[0] + mask = spikes_seg0['unit_index'] == cell_ind + times = spikes_seg0[mask] / sorting.get_sampling_frequency() print(benchmark.recording) # 
si.plot_timeseries(benchmark.recording, mode='line', time_range=(times[0]-0.01, times[0] + 0.1), channel_ids=benchmark.recording.channel_ids[:20], ax=axs[0, 1]) diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py index b82102e9fd..6c0f350ba8 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py @@ -113,26 +113,24 @@ def run(self, peaks=None, positions=None, delta=0.2): if positions is not None: self._positions = positions - times1 = self.gt_sorting.get_all_spike_trains()[0] + spikes1 = self.gt_sorting.to_spike_vector(concatenated=False)[0]["sample_index"] times2 = self.peaks["sample_index"] print("The gt recording has {} peaks and {} have been detected".format(len(times1[0]), len(times2))) - matches = make_matching_events(times1[0], times2, int(delta * self.sampling_rate / 1000)) + matches = make_matching_events(spikes1['sample_index'], times2, int(delta * self.sampling_rate / 1000)) self.matches = matches self.deltas = {"labels": [], "delta": matches["delta_frame"]} - self.deltas["labels"] = times1[1][matches["index1"]] + self.deltas["labels"] = spikes1['unit_index'][matches["index1"]] - # print(len(times1[0]), len(matches['index1'])) gt_matches = matches["index1"] - self.sliced_gt_sorting = NumpySorting.from_times_labels( - times1[0][gt_matches], times1[1][gt_matches], self.sampling_rate, unit_ids=self.gt_sorting.unit_ids - ) - ratio = 100 * len(gt_matches) / len(times1[0]) + self.sliced_gt_sorting = NumpySorting(spikes1[gt_matches], self.sampling_rate, self.gt_sorting.unit_ids) + + ratio = 100 * len(gt_matches) / len(spikes1) print("Only {0:.2f}% of gt peaks are matched to detected peaks".format(ratio)) - matches = make_matching_events(times2, times1[0], int(delta * self.sampling_rate / 1000)) + matches = make_matching_events(times2, spikes1['sample_index'], int(delta * self.sampling_rate / 1000)) self.good_matches = matches["index1"] garbage_matches = ~np.in1d(np.arange(len(times2)), self.good_matches) @@ -231,10 +229,11 @@ def _scatter_clusters( # scatter and collect gaussian info means = {} covs = {} - labels_ids = sorting.get_all_spike_trains()[0][1] + labels = sorting.to_spike_vector(concatenated=False)[0]['unit_index'] for unit_ind, unit_id in enumerate(sorting.unit_ids): - where = np.flatnonzero(labels_ids == unit_id) + where = np.flatnonzero(labels == unit_ind) + xk = xs[where] yk = ys[where] @@ -539,11 +538,11 @@ def plot_statistics(self, metric="cosine", annotations=True, detect_threshold=5) nb_spikes += [b] centers = compute_center_of_mass(self.waveforms["gt"]) - times, labels = self.sliced_gt_sorting.get_all_spike_trains()[0] + spikes_seg0 = self.sliced_gt_sorting.to_spike_vector(concatenated=False)[0] stds = [] means = [] - for found, real in zip(unit_ids2, inds_1): - mask = labels == found + for found, real in zip(inds_2, inds_1): + mask = spikes_seg0['unit_index'] == found center = np.array([self.sliced_gt_positions[mask]["x"], self.sliced_gt_positions[mask]["y"]]).mean() means += [np.mean(center - centers[real])] stds += [np.std(center - centers[real])] @@ -613,22 +612,23 @@ def plot_statistics(self, metric="cosine", annotations=True, detect_threshold=5) def explore_garbage(self, channel_index, nb_bins=None, dt=None): mask = self.garbage_peaks["channel_index"] == channel_index times2 = self.garbage_peaks[mask]["sample_index"] - times1 = 
self.gt_sorting.get_all_spike_trains()[0] + spikes1 = self.gt_sorting.to_spike_vector(concatenate=False)[0] + from spikeinterface.comparison.comparisontools import make_matching_events if dt is None: delta = self.waveforms["garbage"].nafter else: delta = dt - matches = make_matching_events(times2, times1[0], delta) - units = times1[1][matches["index2"]] + matches = make_matching_events(times2, spikes1['sample_index'], delta) + unit_inds = spikes1['unit_index'][matches["index2"]] dt = matches["delta_frame"] res = {} fig, ax = plt.subplots() if nb_bins is None: nb_bins = 2 * delta xaxis = np.linspace(-delta, delta, nb_bins) - for unit_id in np.unique(units): - mask = units == unit_id - res[unit_id] = dt[mask] - ax.hist(res[unit_id], bins=xaxis) + for unit_ind in np.unique(unit_inds): + mask = unit_inds == unit_ind + res[unit_ind] = dt[mask] + ax.hist(res[unit_ind], bins=xaxis) diff --git a/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py b/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py index 380bd67a94..a21ffd0335 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py +++ b/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py @@ -73,7 +73,7 @@ def sorting_fixture(): def spike_trains(sorting): - spike_trains = sorting.get_all_spike_trains()[0][0] + spike_trains = sorting.to_spike_vector()['sample_index'] return spike_trains From 5835b45f962f8ddc53dda97cd25df0877efa69b6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 6 Jun 2023 14:08:11 +0000 Subject: [PATCH 009/166] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/basesorting.py | 5 ++--- src/spikeinterface/core/snippets_tools.py | 2 +- src/spikeinterface/core/tests/test_basesnippets.py | 2 +- src/spikeinterface/postprocessing/correlograms.py | 8 ++++---- src/spikeinterface/postprocessing/isi.py | 4 ++-- .../postprocessing/principal_component.py | 4 ++-- .../postprocessing/spike_amplitudes.py | 12 +++++------- .../tests/test_principal_component.py | 2 +- src/spikeinterface/qualitymetrics/misc_metrics.py | 4 ++-- .../benchmark/benchmark_clustering.py | 7 ++++--- .../benchmark/benchmark_peak_localization.py | 2 +- .../benchmark/benchmark_peak_selection.py | 14 +++++++------- .../sortingcomponents/tests/test_peak_detection.py | 2 +- 13 files changed, 33 insertions(+), 35 deletions(-) diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index 3712306c28..b8cc35c9c8 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -386,8 +386,8 @@ def to_spike_vector(self, concatenated=True, extremum_channel_inds=None): unit_indices.append(np.full(spike_times.size, u, dtype="int64")) if len(sample_indices) > 0: - sample_indices = np.concatenate(sample_indices, dtype='int64') - unit_indices = np.concatenate(unit_indices, dtype='int64') + sample_indices = np.concatenate(sample_indices, dtype="int64") + unit_indices = np.concatenate(unit_indices, dtype="int64") order = np.argsort(sample_indices) sample_indices = sample_indices[order] unit_indices = unit_indices[order] @@ -406,7 +406,6 @@ def to_spike_vector(self, concatenated=True, extremum_channel_inds=None): if concatenated: spikes = np.concatenate(spikes) - return spikes diff --git a/src/spikeinterface/core/snippets_tools.py b/src/spikeinterface/core/snippets_tools.py index 454d3622f3..7f342ef604 100644 
--- a/src/spikeinterface/core/snippets_tools.py +++ b/src/spikeinterface/core/snippets_tools.py @@ -58,7 +58,7 @@ def snippets_from_sorting(recording, sorting, nbefore=20, nafter=44, wf_folder=N nse = NumpySnippets( snippets_list=wfs, - spikesframes_list=[s['sample_index'] for s in spikes], + spikesframes_list=[s["sample_index"] for s in spikes], sampling_frequency=recording.get_sampling_frequency(), nbefore=nbefore, channel_ids=recording.get_channel_ids(), diff --git a/src/spikeinterface/core/tests/test_basesnippets.py b/src/spikeinterface/core/tests/test_basesnippets.py index 3fd5091486..544f6315df 100644 --- a/src/spikeinterface/core/tests/test_basesnippets.py +++ b/src/spikeinterface/core/tests/test_basesnippets.py @@ -87,7 +87,7 @@ def test_BaseSnippets(): times0 = snippets.get_frames(segment_index=0) - seg0_times = sorting.to_spike_vector(concatenated=False)[0]['sample_index'] + seg0_times = sorting.to_spike_vector(concatenated=False)[0]["sample_index"] assert np.array_equal(seg0_times, times0) diff --git a/src/spikeinterface/postprocessing/correlograms.py b/src/spikeinterface/postprocessing/correlograms.py index d6e074fc2c..3cb8e9b96b 100644 --- a/src/spikeinterface/postprocessing/correlograms.py +++ b/src/spikeinterface/postprocessing/correlograms.py @@ -224,8 +224,8 @@ def compute_correlograms_numpy(sorting, window_size, bin_size): correlograms = np.zeros((num_units, num_units, num_bins), dtype="int64") for seg_index in range(num_seg): - spike_times = spikes[seg_index]['sample_index'] - spike_labels = spikes[seg_index]['unit_index'] + spike_times = spikes[seg_index]["sample_index"] + spike_labels = spikes[seg_index]["unit_index"] c0 = correlogram_for_one_segment(spike_times, spike_labels, window_size, bin_size) @@ -310,8 +310,8 @@ def compute_correlograms_numba(sorting, window_size, bin_size): correlograms = np.zeros((num_units, num_units, num_bins), dtype=np.int64) for seg_index in range(sorting.get_num_segments()): - spike_times = spikes[seg_index]['sample_index'] - spike_labels = spikes[seg_index]['unit_index'] + spike_times = spikes[seg_index]["sample_index"] + spike_labels = spikes[seg_index]["unit_index"] _compute_correlograms_numba( correlograms, spike_times.astype(np.int64), spike_labels.astype(np.int32), window_size, bin_size diff --git a/src/spikeinterface/postprocessing/isi.py b/src/spikeinterface/postprocessing/isi.py index eac10fa763..aec70141cf 100644 --- a/src/spikeinterface/postprocessing/isi.py +++ b/src/spikeinterface/postprocessing/isi.py @@ -238,8 +238,8 @@ def compute_isi_histograms_numba(sorting, window_ms: float = 50.0, bin_ms: float ISIs = np.zeros((num_units, num_bins), dtype=np.int64) for seg_index in range(sorting.get_num_segments()): - spike_times = spikes[seg_index]['sample_index'].astype(np.int64) - spike_labels = spikes[seg_index]['unit_index'].astype(np.int32) + spike_times = spikes[seg_index]["sample_index"].astype(np.int64) + spike_labels = spikes[seg_index]["unit_index"].astype(np.int32) _compute_isi_histograms_numba( ISIs, diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py index 84cbeb9696..48d4f49ebe 100644 --- a/src/spikeinterface/postprocessing/principal_component.py +++ b/src/spikeinterface/postprocessing/principal_component.py @@ -308,8 +308,8 @@ def run_for_all_spikes(self, file_path=None, **job_kwargs): file_path = Path(file_path) spikes = sorting.to_spike_vector(concatenated=False) - spike_times = spikes['sample_index'] - spike_labels = spikes['unit_index'] + 
spike_times = spikes["sample_index"] + spike_labels = spikes["unit_index"] sparsity = self.get_sparsity() if sparsity is None: diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py index 77adb0536f..6790e6d113 100644 --- a/src/spikeinterface/postprocessing/spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/spike_amplitudes.py @@ -28,13 +28,13 @@ def _select_extension_data(self, unit_ids): # load filter and save amplitude files sorting = self.waveform_extractor.sorting spikes = sorting.to_spike_vector(concatenated=False) - keep_unit_indices, = np.nonzero(np.in1d(sorting.unit_ids, unit_ids)) + (keep_unit_indices,) = np.nonzero(np.in1d(sorting.unit_ids, unit_ids)) new_extension_data = dict() for seg_index in range(sorting.get_num_segments()): amp_data_name = f"amplitude_segment_{seg_index}" amps = self._extension_data[amp_data_name] - filtered_idxs = np.in1d(spikes[seg_index]['unit_index'], keep_unit_indices) + filtered_idxs = np.in1d(spikes[seg_index]["unit_index"], keep_unit_indices) new_extension_data[amp_data_name] = amps[filtered_idxs] return new_extension_data @@ -110,7 +110,6 @@ def get_data(self, outputs="concatenated"): """ we = self.waveform_extractor sorting = we.sorting - if outputs == "concatenated": amplitudes = [] @@ -124,7 +123,7 @@ def get_data(self, outputs="concatenated"): for segment_index in range(we.get_num_segments()): amplitudes_by_unit.append({}) for unit_index, unit_id in enumerate(sorting.unit_ids): - spike_labels = all_spikes[segment_index]['unit_index'] + spike_labels = all_spikes[segment_index]["unit_index"] mask = spike_labels == unit_index amps = self._extension_data[f"amplitude_segment_{segment_index}"][mask] amplitudes_by_unit[segment_index][unit_id] = amps @@ -198,7 +197,6 @@ def _init_worker_spike_amplitudes(recording, sorting, extremum_channels_index, p worker_ctx["min_shift"] = np.min(peak_shifts) worker_ctx["max_shifts"] = np.max(peak_shifts) - worker_ctx["all_spikes"] = sorting.to_spike_vector(concatenated=False) worker_ctx["extremum_channels_index"] = extremum_channels_index @@ -214,8 +212,8 @@ def _spike_amplitudes_chunk(segment_index, start_frame, end_frame, worker_ctx): seg_size = recording.get_num_samples(segment_index=segment_index) - spike_times = all_spikes[segment_index]['sample_index'] - spike_labels = all_spikes[segment_index]['unit_index'] + spike_times = all_spikes[segment_index]["sample_index"] + spike_labels = all_spikes[segment_index]["unit_index"] d = np.diff(spike_times) assert np.all(d >= 0) diff --git a/src/spikeinterface/postprocessing/tests/test_principal_component.py b/src/spikeinterface/postprocessing/tests/test_principal_component.py index 90d6a2f3f8..b73477d306 100644 --- a/src/spikeinterface/postprocessing/tests/test_principal_component.py +++ b/src/spikeinterface/postprocessing/tests/test_principal_component.py @@ -67,7 +67,7 @@ def test_compute_for_all_spikes(self): all_spikes_seg0 = we_copy.sorting.to_spike_vector(concatenated=False)[0] for unit_index, unit_id in enumerate(we.unit_ids): sparse_channel_ids = sparsity.unit_id_to_channel_ids[unit_id] - pc_unit = all_pc_sparse[all_spikes_seg0['unit_index'] == unit_index] + pc_unit = all_pc_sparse[all_spikes_seg0["unit_index"] == unit_index] assert np.allclose(pc_unit[:, :, len(sparse_channel_ids) :], 0) def test_sparse(self): diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/qualitymetrics/misc_metrics.py index 42eb3e6677..66ccd60b77 100644 --- 
a/src/spikeinterface/qualitymetrics/misc_metrics.py +++ b/src/spikeinterface/qualitymetrics/misc_metrics.py @@ -350,8 +350,8 @@ def compute_refrac_period_violations( nb_rp_violations = np.zeros((num_units), dtype=np.int64) for seg_index in range(num_segments): - spike_times = spikes[seg_index]['sample_index'].astype(np.int64) - spike_labels = spikes[seg_index]['unit_index'].astype(np.int32) + spike_times = spikes[seg_index]["sample_index"].astype(np.int64) + spike_labels = spikes[seg_index]["unit_index"].astype(np.int32) _compute_rp_violations_numba(nb_rp_violations, spike_times, spike_labels, t_c, t_r) T = waveform_extractor.get_total_samples() diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py index da416ba8f4..d68b8e5449 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py @@ -164,8 +164,9 @@ def run(self, peaks=None, positions=None, method=None, method_kwargs={}, delta=0 spikes1 = self.gt_sorting.to_spike_vector(concatenated=False)[0] spikes2 = self.clustering.to_spike_vector(concatenated=False)[0] - matches = make_matching_events(spikes1['sample_index'], spikes2['sample_index'], - int(delta * self.sampling_rate / 1000)) + matches = make_matching_events( + spikes1["sample_index"], spikes2["sample_index"], int(delta * self.sampling_rate / 1000) + ) self.matches = matches idx = matches["index1"] @@ -251,7 +252,7 @@ def _scatter_clusters( # scatter and collect gaussian info means = {} covs = {} - labels = sorting.to_spike_vector(concatenated=False)[0]['unit_index'] + labels = sorting.to_spike_vector(concatenated=False)[0]["unit_index"] for unit_ind, unit_id in enumerate(sorting.unit_ids): where = np.flatnonzero(labels == unit_ind) diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py index 84b8db7892..3132de71ae 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py @@ -451,7 +451,7 @@ def plot_figure_1(benchmark, mode="average", cell_ind="auto"): unit_id = sorting.unit_ids[cell_ind] spikes_seg0 = sorting.to_spike_vector(concatenated=False)[0] - mask = spikes_seg0['unit_index'] == cell_ind + mask = spikes_seg0["unit_index"] == cell_ind times = spikes_seg0[mask] / sorting.get_sampling_frequency() print(benchmark.recording) diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py index 6c0f350ba8..1514a63dd4 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_selection.py @@ -118,11 +118,11 @@ def run(self, peaks=None, positions=None, delta=0.2): print("The gt recording has {} peaks and {} have been detected".format(len(times1[0]), len(times2))) - matches = make_matching_events(spikes1['sample_index'], times2, int(delta * self.sampling_rate / 1000)) + matches = make_matching_events(spikes1["sample_index"], times2, int(delta * self.sampling_rate / 1000)) self.matches = matches self.deltas = {"labels": [], "delta": matches["delta_frame"]} - self.deltas["labels"] = spikes1['unit_index'][matches["index1"]] + self.deltas["labels"] = 
spikes1["unit_index"][matches["index1"]] gt_matches = matches["index1"] self.sliced_gt_sorting = NumpySorting(spikes1[gt_matches], self.sampling_rate, self.gt_sorting.unit_ids) @@ -130,7 +130,7 @@ def run(self, peaks=None, positions=None, delta=0.2): ratio = 100 * len(gt_matches) / len(spikes1) print("Only {0:.2f}% of gt peaks are matched to detected peaks".format(ratio)) - matches = make_matching_events(times2, spikes1['sample_index'], int(delta * self.sampling_rate / 1000)) + matches = make_matching_events(times2, spikes1["sample_index"], int(delta * self.sampling_rate / 1000)) self.good_matches = matches["index1"] garbage_matches = ~np.in1d(np.arange(len(times2)), self.good_matches) @@ -229,7 +229,7 @@ def _scatter_clusters( # scatter and collect gaussian info means = {} covs = {} - labels = sorting.to_spike_vector(concatenated=False)[0]['unit_index'] + labels = sorting.to_spike_vector(concatenated=False)[0]["unit_index"] for unit_ind, unit_id in enumerate(sorting.unit_ids): where = np.flatnonzero(labels == unit_ind) @@ -542,7 +542,7 @@ def plot_statistics(self, metric="cosine", annotations=True, detect_threshold=5) stds = [] means = [] for found, real in zip(inds_2, inds_1): - mask = spikes_seg0['unit_index'] == found + mask = spikes_seg0["unit_index"] == found center = np.array([self.sliced_gt_positions[mask]["x"], self.sliced_gt_positions[mask]["y"]]).mean() means += [np.mean(center - centers[real])] stds += [np.std(center - centers[real])] @@ -620,8 +620,8 @@ def explore_garbage(self, channel_index, nb_bins=None, dt=None): delta = self.waveforms["garbage"].nafter else: delta = dt - matches = make_matching_events(times2, spikes1['sample_index'], delta) - unit_inds = spikes1['unit_index'][matches["index2"]] + matches = make_matching_events(times2, spikes1["sample_index"], delta) + unit_inds = spikes1["unit_index"][matches["index2"]] dt = matches["delta_frame"] res = {} fig, ax = plt.subplots() diff --git a/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py b/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py index a21ffd0335..b203399d18 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py +++ b/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py @@ -73,7 +73,7 @@ def sorting_fixture(): def spike_trains(sorting): - spike_trains = sorting.to_spike_vector()['sample_index'] + spike_trains = sorting.to_spike_vector()["sample_index"] return spike_trains From 1e6d56fbe9fbcb24957c7bd9d822d8772688b76c Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 6 Jun 2023 16:53:15 +0200 Subject: [PATCH 010/166] Implement to caching : self._cached_spike_vector + self._cached_spike_trains --- src/spikeinterface/core/basesorting.py | 125 +++++++++++++----- src/spikeinterface/core/numpyextractors.py | 4 + .../postprocessing/amplitude_scalings.py | 3 +- 3 files changed, 96 insertions(+), 36 deletions(-) diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index 3712306c28..96455f54e3 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -23,6 +23,10 @@ def __init__(self, sampling_frequency: float, unit_ids: List): self._recording = None self._sorting_info = None + # caching + self._cached_spike_vector = None + self._cached_spike_trains = {} + def __repr__(self): clsname = self.__class__.__name__ nseg = self.get_num_segments() @@ -109,12 +113,30 @@ def get_unit_spike_train( start_frame: Union[int, None] = None, end_frame: Union[int, None] = None, 
        return_times: bool = False,
+        use_cache: bool = True,
     ):
         segment_index = self._check_segment_index(segment_index)
-        segment = self._sorting_segments[segment_index]
-        spike_frames = segment.get_unit_spike_train(
-            unit_id=unit_id, start_frame=start_frame, end_frame=end_frame
-        ).astype("int64")
+        if use_cache:
+            if segment_index not in self._cached_spike_trains:
+                self._cached_spike_trains[segment_index] = {}
+            if unit_id not in self._cached_spike_trains[segment_index]:
+                segment = self._sorting_segments[segment_index]
+                spike_frames = segment.get_unit_spike_train(unit_id=unit_id,
+                                                            start_frame=None,
+                                                            end_frame=None).astype("int64")
+                self._cached_spike_trains[segment_index][unit_id] = spike_frames
+            else:
+                spike_frames = self._cached_spike_trains[segment_index][unit_id]
+            if start_frame is not None:
+                spike_frames = spike_frames[spike_frames >= start_frame]
+            if end_frame is not None:
+                spike_frames = spike_frames[spike_frames < end_frame]
+        else:
+            segment = self._sorting_segments[segment_index]
+            spike_frames = segment.get_unit_spike_train(
+                unit_id=unit_id, start_frame=start_frame, end_frame=end_frame
+            ).astype("int64")
+
         if return_times:
             if self.has_recording():
                 times = self.get_times(segment_index=segment_index)
@@ -346,7 +368,7 @@ def get_all_spike_trains(self, outputs="unit_id"):
             spikes.append((spike_times, spike_labels))
         return spikes

-    def to_spike_vector(self, concatenated=True, extremum_channel_inds=None):
+    def to_spike_vector(self, concatenated=True, extremum_channel_inds=None, use_cache=True):
         """
         Construct a unique structured numpy vector concatenating all spikes
         with several fields: sample_index, unit_index, segment_index.
@@ -356,13 +378,16 @@
         Parameters
         ----------
         concatenated: bool
-            By default the output is one numpy vector.
-            With concatenated=False then it is a list of vector by segment.
+            By default the output is one numpy vector with all spikes from all segments.
+            With concatenated=False, it is a list of spike vectors, one per segment.
         extremum_channel_inds: None or dict
            If a dictionary of unit_id to channel_ind is given, then an extra field 'channel_index' is added.
            This can be convenient for computing spike positions after sorting.
            This dict can be computed with `get_template_extremum_channel(we, outputs="index")`
+        use_cache: bool
+            When True (default) the spike vector is cached in an attribute of the object.
+            This caching only occurs when extremum_channel_inds=None.
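+
+        A minimal sketch of the intended usage (illustrative, assuming `sorting` holds at least one spike):
+
+        >>> spikes = sorting.to_spike_vector()
+        >>> spikes["sample_index"]  # spike times in samples, all segments concatenated
+        >>> spikes["unit_index"]  # the unit index of each spike
+        >>> by_segment = sorting.to_spike_vector(concatenated=False)  # a list, one vector per segment
+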
Returns ------- @@ -372,40 +397,70 @@ def to_spike_vector(self, concatenated=True, extremum_channel_inds=None): is given """ + spike_dtype = minimum_spike_dtype if extremum_channel_inds is not None: spike_dtype += [("channel_index", "int64")] - spikes = [] - for segment_index in range(self.get_num_segments()): - sample_indices = [] - unit_indices = [] - for u, unit_id in enumerate(self.unit_ids): - spike_times = st = self.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index) - sample_indices.append(spike_times) - unit_indices.append(np.full(spike_times.size, u, dtype="int64")) - - if len(sample_indices) > 0: - sample_indices = np.concatenate(sample_indices, dtype='int64') - unit_indices = np.concatenate(unit_indices, dtype='int64') - order = np.argsort(sample_indices) - sample_indices = sample_indices[order] - unit_indices = unit_indices[order] - - spikes_in_seg = np.zeros(len(sample_indices), dtype=minimum_spike_dtype) - spikes_in_seg["sample_index"] = sample_indices - spikes_in_seg["unit_index"] = unit_indices - spikes_in_seg["segment_index"] = segment_index - spikes.append(spikes_in_seg) - + if use_cache and self._cached_spike_vector is not None: + # the cache already exists if extremum_channel_inds is not None: - ext_channel_inds = np.array([extremum_channel_inds[unit_id] for unit_id in self.unit_ids]) - # vector way - spikes_in_seg["channel_index"] = ext_channel_inds[spikes_in_seg["unit_index"]] - - if concatenated: - spikes = np.concatenate(spikes) + spikes = self._cached_spike_vector + else: + spikes = np.zeros(self._cached_spike_vector.size, dtype=spike_dtype) + spikes["sample_index"] = self._cached_spike_vector["sample_index"] + spikes["unit_index"] = self._cached_spike_vector["unit_index"] + spikes["segment_index"] = self._cached_spike_vector["segment_index"] + if extremum_channel_inds is not None: + ext_channel_inds = np.array([extremum_channel_inds[unit_id] for unit_id in self.unit_ids]) + spikes["channel_index"] = ext_channel_inds[spikes_in_seg["unit_index"]] + + if not concatenated: + spikes_ = [] + for segment_index in range(self.get_num_segments()): + s0 = np.searchsorted(spikes["segment_index"], segment_index, side="left") + s1 = np.searchsorted(spikes["segment_index"], segment_index + 1, side="left") + spikes_.append(spikes[s0:s1]) + spikes = spikes_ + else: + # the cache not needed or do not exists yet + spikes = [] + for segment_index in range(self.get_num_segments()): + sample_indices = [] + unit_indices = [] + for u, unit_id in enumerate(self.unit_ids): + spike_times = st = self.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index) + sample_indices.append(spike_times) + unit_indices.append(np.full(spike_times.size, u, dtype="int64")) + + if len(sample_indices) > 0: + sample_indices = np.concatenate(sample_indices, dtype='int64') + unit_indices = np.concatenate(unit_indices, dtype='int64') + order = np.argsort(sample_indices) + sample_indices = sample_indices[order] + unit_indices = unit_indices[order] + + spikes_in_seg = np.zeros(len(sample_indices), dtype=minimum_spike_dtype) + spikes_in_seg["sample_index"] = sample_indices + spikes_in_seg["unit_index"] = unit_indices + spikes_in_seg["segment_index"] = segment_index + spikes.append(spikes_in_seg) + + if extremum_channel_inds is not None: + ext_channel_inds = np.array([extremum_channel_inds[unit_id] for unit_id in self.unit_ids]) + # vector way + spikes_in_seg["channel_index"] = ext_channel_inds[spikes_in_seg["unit_index"]] + + if concatenated: + spikes = np.concatenate(spikes) + + if 
use_cache and self._cached_spike_vector is None and extremum_channel_inds is None: + # cache it if necessary but only without "channel_index" + if concatenated: + self._cached_spike_vector = spikes + else: + self._cached_spike_vector = np.concatenate(spikes) return spikes diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py index 35febad6d7..c568a94b5b 100644 --- a/src/spikeinterface/core/numpyextractors.py +++ b/src/spikeinterface/core/numpyextractors.py @@ -118,6 +118,7 @@ class NumpySorting(BaseSorting): def __init__(self, spikes, sampling_frequency, unit_ids): """ """ BaseSorting.__init__(self, sampling_frequency, unit_ids) + self.is_dumpable = True if spikes.size == 0: @@ -127,6 +128,9 @@ def __init__(self, spikes, sampling_frequency, unit_ids): for segment_index in range(nseg): self.add_sorting_segment(NumpySortingSegment(spikes, segment_index, unit_ids)) + + # important trick : the cache is already spikes vector + self._cached_spike_vector = spikes self._kwargs = dict(spikes=spikes, sampling_frequency=sampling_frequency, unit_ids=unit_ids) diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index dc3624ba3e..62df2f42cc 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -18,7 +18,8 @@ def __init__(self, waveform_extractor): BaseWaveformExtractorExtension.__init__(self, waveform_extractor) extremum_channel_inds = get_template_extremum_channel(self.waveform_extractor, outputs="index") - self.spikes = self.waveform_extractor.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds) + self.spikes = self.waveform_extractor.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds, + use_cache=False) def _set_params(self, sparsity, max_dense_channels, ms_before, ms_after): params = dict(sparsity=sparsity, max_dense_channels=max_dense_channels, ms_before=ms_before, ms_after=ms_after) From 5746aa61cf67b2e0c78b5d2e97e3fc8519350eb4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 6 Jun 2023 14:55:55 +0000 Subject: [PATCH 011/166] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/basesorting.py | 13 ++++++------- src/spikeinterface/core/numpyextractors.py | 2 +- .../postprocessing/amplitude_scalings.py | 5 +++-- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index 8251761ec0..e206ac7eee 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -121,9 +121,9 @@ def get_unit_spike_train( self._cached_spike_trains[segment_index] = {} if unit_id not in self._cached_spike_trains[segment_index]: segment = self._sorting_segments[segment_index] - spike_frames = segment.get_unit_spike_train(unit_id=unit_id, - start_frame=None, - end_frame=None).astype("int64") + spike_frames = segment.get_unit_spike_train(unit_id=unit_id, start_frame=None, end_frame=None).astype( + "int64" + ) self._cached_spike_trains[segment_index][unit_id] = spike_frames else: spike_frames = self._cached_spike_trains[segment_index][unit_id] @@ -402,7 +402,6 @@ def to_spike_vector(self, concatenated=True, extremum_channel_inds=None, use_cac if extremum_channel_inds is not None: spike_dtype += [("channel_index", "int64")] - if 
use_cache and self._cached_spike_vector is not None: # the cache already exists if extremum_channel_inds is not None: @@ -436,8 +435,8 @@ def to_spike_vector(self, concatenated=True, extremum_channel_inds=None, use_cac unit_indices.append(np.full(spike_times.size, u, dtype="int64")) if len(sample_indices) > 0: - sample_indices = np.concatenate(sample_indices, dtype='int64') - unit_indices = np.concatenate(unit_indices, dtype='int64') + sample_indices = np.concatenate(sample_indices, dtype="int64") + unit_indices = np.concatenate(unit_indices, dtype="int64") order = np.argsort(sample_indices) sample_indices = sample_indices[order] unit_indices = unit_indices[order] @@ -455,7 +454,7 @@ def to_spike_vector(self, concatenated=True, extremum_channel_inds=None, use_cac if concatenated: spikes = np.concatenate(spikes) - + if use_cache and self._cached_spike_vector is None and extremum_channel_inds is None: # cache it if necessary but only without "channel_index" if concatenated: diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py index c568a94b5b..c17b89a296 100644 --- a/src/spikeinterface/core/numpyextractors.py +++ b/src/spikeinterface/core/numpyextractors.py @@ -128,7 +128,7 @@ def __init__(self, spikes, sampling_frequency, unit_ids): for segment_index in range(nseg): self.add_sorting_segment(NumpySortingSegment(spikes, segment_index, unit_ids)) - + # important trick : the cache is already spikes vector self._cached_spike_vector = spikes diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 62df2f42cc..3ebeafcfec 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -18,8 +18,9 @@ def __init__(self, waveform_extractor): BaseWaveformExtractorExtension.__init__(self, waveform_extractor) extremum_channel_inds = get_template_extremum_channel(self.waveform_extractor, outputs="index") - self.spikes = self.waveform_extractor.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds, - use_cache=False) + self.spikes = self.waveform_extractor.sorting.to_spike_vector( + extremum_channel_inds=extremum_channel_inds, use_cache=False + ) def _set_params(self, sparsity, max_dense_channels, ms_before, ms_after): params = dict(sparsity=sparsity, max_dense_channels=max_dense_channels, ms_before=ms_before, ms_after=ms_after) From 5173233fb434c0913f684d1ac0f6b57edef76e23 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 6 Jun 2023 19:20:16 +0200 Subject: [PATCH 012/166] oups --- src/spikeinterface/core/basesorting.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index 8251761ec0..4ffca242d0 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -400,7 +400,7 @@ def to_spike_vector(self, concatenated=True, extremum_channel_inds=None, use_cac spike_dtype = minimum_spike_dtype if extremum_channel_inds is not None: - spike_dtype += [("channel_index", "int64")] + spike_dtype = spike_dtype + [("channel_index", "int64")] if use_cache and self._cached_spike_vector is not None: @@ -442,16 +442,16 @@ def to_spike_vector(self, concatenated=True, extremum_channel_inds=None, use_cac sample_indices = sample_indices[order] unit_indices = unit_indices[order] - spikes_in_seg = np.zeros(len(sample_indices), dtype=minimum_spike_dtype) + spikes_in_seg = 
np.zeros(len(sample_indices), dtype=spike_dtype) spikes_in_seg["sample_index"] = sample_indices spikes_in_seg["unit_index"] = unit_indices spikes_in_seg["segment_index"] = segment_index - spikes.append(spikes_in_seg) - if extremum_channel_inds is not None: ext_channel_inds = np.array([extremum_channel_inds[unit_id] for unit_id in self.unit_ids]) # vector way spikes_in_seg["channel_index"] = ext_channel_inds[spikes_in_seg["unit_index"]] + spikes.append(spikes_in_seg) + if concatenated: spikes = np.concatenate(spikes) From d759d87064c46e75c4990e6d9d9226ccb76afa9f Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 6 Jun 2023 19:21:04 +0200 Subject: [PATCH 013/166] wip --- src/spikeinterface/core/numpyextractors.py | 24 ++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py index c568a94b5b..d0e02a19ef 100644 --- a/src/spikeinterface/core/numpyextractors.py +++ b/src/spikeinterface/core/numpyextractors.py @@ -10,6 +10,9 @@ BaseSnippetsSegment, ) from .basesorting import minimum_spike_dtype +from .core_tools import make_shared_array + +from multiprocessing.shared_memory import SharedMemory from typing import List, Union @@ -342,6 +345,27 @@ def get_unit_spike_train(self, unit_id, start_frame, end_frame): return times +# class SharedMemmorySorting(BaseSorting): +# def __init__(self, shm_name, shape, dtype=minimum_spike_dtype): + +# shm = SharedMemory(shm_name) +# arr = np.ndarray(shape=shape, dtype=dtype, buffer=shm.buf) + + + + +# for segment_index in range(nseg): +# self.add_sorting_segment(NumpySortingSegment(spikes, segment_index, unit_ids)) + +# @staticmethod +# def from_sorting(source_sorting: BaseSorting) -> "SharedMemmorySorting": + +# make_shared_array(shape, dtype) + + + + + class NumpyEvent(BaseEvent): def __init__(self, channel_ids, structured_dtype): BaseEvent.__init__(self, channel_ids, structured_dtype) From 6e9c93164f5a74e37186f3f03b54a7e04de574f0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 6 Jun 2023 17:21:47 +0000 Subject: [PATCH 014/166] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/basesorting.py | 1 - src/spikeinterface/core/numpyextractors.py | 5 ----- 2 files changed, 6 deletions(-) diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index 3f3f5f2075..1cb094c8bc 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -451,7 +451,6 @@ def to_spike_vector(self, concatenated=True, extremum_channel_inds=None, use_cac spikes_in_seg["channel_index"] = ext_channel_inds[spikes_in_seg["unit_index"]] spikes.append(spikes_in_seg) - if concatenated: spikes = np.concatenate(spikes) diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py index 331ef5a182..5373746f49 100644 --- a/src/spikeinterface/core/numpyextractors.py +++ b/src/spikeinterface/core/numpyextractors.py @@ -352,8 +352,6 @@ def get_unit_spike_train(self, unit_id, start_frame, end_frame): # arr = np.ndarray(shape=shape, dtype=dtype, buffer=shm.buf) - - # for segment_index in range(nseg): # self.add_sorting_segment(NumpySortingSegment(spikes, segment_index, unit_ids)) @@ -363,9 +361,6 @@ def get_unit_spike_train(self, unit_id, start_frame, end_frame): # make_shared_array(shape, dtype) - - - class NumpyEvent(BaseEvent): def __init__(self, 
channel_ids, structured_dtype):
         BaseEvent.__init__(self, channel_ids, structured_dtype)

From bc0fd670381e9e022384c1f1e6f8613424435f09 Mon Sep 17 00:00:00 2001
From: Samuel Garcia
Date: Tue, 6 Jun 2023 22:15:10 +0200
Subject: [PATCH 015/166] Various fixes due to the to_spike_vector() refactoring

---
 src/spikeinterface/core/basesorting.py        |  8 ++++----
 .../postprocessing/correlograms.py            |  2 +-
 .../postprocessing/principal_component.py     |  2 ++
 .../postprocessing/spike_locations.py         |  3 ---
 .../postprocessing/tests/test_correlograms.py | 16 ++++++++--------
 .../tests/test_principal_component.py         |  8 ++++----
 6 files changed, 19 insertions(+), 20 deletions(-)

diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py
index 3f3f5f2075..8d4c05c704 100644
--- a/src/spikeinterface/core/basesorting.py
+++ b/src/spikeinterface/core/basesorting.py
@@ -401,10 +401,11 @@ def to_spike_vector(self, concatenated=True, extremum_channel_inds=None, use_cac
         spike_dtype = minimum_spike_dtype
         if extremum_channel_inds is not None:
             spike_dtype = spike_dtype + [("channel_index", "int64")]
+            ext_channel_inds = np.array([extremum_channel_inds[unit_id] for unit_id in self.unit_ids])

         if use_cache and self._cached_spike_vector is not None:
             # the cache already exists
-            if extremum_channel_inds is not None:
+            if extremum_channel_inds is None:
                 spikes = self._cached_spike_vector
             else:
                 spikes = np.zeros(self._cached_spike_vector.size, dtype=spike_dtype)
@@ -412,8 +413,8 @@ def to_spike_vector(self, concatenated=True, extremum_channel_inds=None, use_cac
                 spikes["unit_index"] = self._cached_spike_vector["unit_index"]
                 spikes["segment_index"] = self._cached_spike_vector["segment_index"]
                 if extremum_channel_inds is not None:
-                    ext_channel_inds = np.array([extremum_channel_inds[unit_id] for unit_id in self.unit_ids])
-                    spikes["channel_index"] = ext_channel_inds[spikes_in_seg["unit_index"]]
+
+                    spikes["channel_index"] = ext_channel_inds[spikes["unit_index"]]

             if not concatenated:
                 spikes_ = []
                 for segment_index in range(self.get_num_segments()):
@@ -446,7 +447,6 @@ def to_spike_vector(self, concatenated=True, extremum_channel_inds=None, use_cac
                     spikes_in_seg["unit_index"] = unit_indices
                     spikes_in_seg["segment_index"] = segment_index
                     if extremum_channel_inds is not None:
-                        ext_channel_inds = np.array([extremum_channel_inds[unit_id] for unit_id in self.unit_ids])
                         # vector way
                         spikes_in_seg["channel_index"] = ext_channel_inds[spikes_in_seg["unit_index"]]
                     spikes.append(spikes_in_seg)
diff --git a/src/spikeinterface/postprocessing/correlograms.py b/src/spikeinterface/postprocessing/correlograms.py
index 3cb8e9b96b..6cd5238abd 100644
--- a/src/spikeinterface/postprocessing/correlograms.py
+++ b/src/spikeinterface/postprocessing/correlograms.py
@@ -306,7 +306,7 @@ def compute_correlograms_numba(sorting, window_size, bin_size):
     num_bins = 2 * int(window_size / bin_size)
     num_units = len(sorting.unit_ids)

-    spikes = sorting.to_spike_vector(concatenated=false)
+    spikes = sorting.to_spike_vector(concatenated=False)
     correlograms = np.zeros((num_units, num_units, num_bins), dtype=np.int64)

     for seg_index in range(sorting.get_num_segments()):
diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py
index 48d4f49ebe..722bd9b7a7 100644
--- a/src/spikeinterface/postprocessing/principal_component.py
+++ b/src/spikeinterface/postprocessing/principal_component.py
@@ -308,6 +308,8 @@ def run_for_all_spikes(self, file_path=None, **job_kwargs):
         file_path = Path(file_path)

         spikes = 
sorting.to_spike_vector(concatenated=False) + # This is the first segment only + spikes = spikes[0] spike_times = spikes["sample_index"] spike_labels = spikes["unit_index"] diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py index aac96be7b6..c6f498f7e8 100644 --- a/src/spikeinterface/postprocessing/spike_locations.py +++ b/src/spikeinterface/postprocessing/spike_locations.py @@ -50,9 +50,6 @@ def _run(self, **job_kwargs): we = self.waveform_extractor - extremum_channel_inds = get_template_extremum_channel(we, outputs="index") - self.spikes = we.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds) - spike_locations = localize_peaks(we.recording, self.spikes, **self._params, **job_kwargs) self._extension_data["spike_locations"] = spike_locations diff --git a/src/spikeinterface/postprocessing/tests/test_correlograms.py b/src/spikeinterface/postprocessing/tests/test_correlograms.py index 9c3529345b..d6648150de 100644 --- a/src/spikeinterface/postprocessing/tests/test_correlograms.py +++ b/src/spikeinterface/postprocessing/tests/test_correlograms.py @@ -204,13 +204,13 @@ def test_detect_injected_correlation(): if __name__ == "__main__": - # ~ test_make_bins() - # test_equal_results_correlograms() - # ~ test_flat_cross_correlogram() - # ~ test_auto_equal_cross_correlograms() + test_make_bins() + test_equal_results_correlograms() + test_flat_cross_correlogram() + test_auto_equal_cross_correlograms() test_detect_injected_correlation() - # test = CorrelogramsExtensionTest() - # test.setUp() - # test.test_compute_correlograms() - # test.test_extension() + test = CorrelogramsExtensionTest() + test.setUp() + test.test_compute_correlograms() + test.test_extension() diff --git a/src/spikeinterface/postprocessing/tests/test_principal_component.py b/src/spikeinterface/postprocessing/tests/test_principal_component.py index b73477d306..5d64525b52 100644 --- a/src/spikeinterface/postprocessing/tests/test_principal_component.py +++ b/src/spikeinterface/postprocessing/tests/test_principal_component.py @@ -197,8 +197,8 @@ def test_project_new(self): if __name__ == "__main__": test = PrincipalComponentsExtensionTest() test.setUp() - # test.test_extension() - # test.test_shapes() - # test.test_compute_for_all_spikes() - # test.test_sparse() + test.test_extension() + test.test_shapes() + test.test_compute_for_all_spikes() + test.test_sparse() test.test_project_new() From 5e83cfbe938b58eca9871b09341622c1ba715f05 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 6 Jun 2023 20:15:44 +0000 Subject: [PATCH 016/166] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/basesorting.py | 1 - src/spikeinterface/postprocessing/principal_component.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index 31183678f7..e05295f2fe 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -413,7 +413,6 @@ def to_spike_vector(self, concatenated=True, extremum_channel_inds=None, use_cac spikes["unit_index"] = self._cached_spike_vector["unit_index"] spikes["segment_index"] = self._cached_spike_vector["segment_index"] if extremum_channel_inds is not None: - spikes["channel_index"] = ext_channel_inds[spikes["unit_index"]] if not concatenated: diff --git 
a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py
index 722bd9b7a7..9465b16db6 100644
--- a/src/spikeinterface/postprocessing/principal_component.py
+++ b/src/spikeinterface/postprocessing/principal_component.py
@@ -308,7 +308,7 @@ def run_for_all_spikes(self, file_path=None, **job_kwargs):
         file_path = Path(file_path)

         spikes = sorting.to_spike_vector(concatenated=False)
-        # This is the first segment only 
+        # This is the first segment only
         spikes = spikes[0]
         spike_times = spikes["sample_index"]
         spike_labels = spikes["unit_index"]

From 9694eeb55c1685f48f626600b586995250ce1d01 Mon Sep 17 00:00:00 2001
From: Samuel Garcia
Date: Tue, 6 Jun 2023 22:55:40 +0200
Subject: [PATCH 017/166] Basic SharedMemmorySorting implementation.

---
 src/spikeinterface/core/__init__.py           |  2 +-
 src/spikeinterface/core/basesorting.py        | 19 +++++++
 src/spikeinterface/core/numpyextractors.py    | 45 ++++++++++++----
 .../core/tests/test_numpy_extractors.py       | 51 +++++++++++++++++--
 4 files changed, 102 insertions(+), 15 deletions(-)

diff --git a/src/spikeinterface/core/__init__.py b/src/spikeinterface/core/__init__.py
index e379777a44..d85aec5787 100644
--- a/src/spikeinterface/core/__init__.py
+++ b/src/spikeinterface/core/__init__.py
@@ -8,7 +8,7 @@
 # main extractor from dump and cache
 from .binaryrecordingextractor import BinaryRecordingExtractor, read_binary
 from .npzsortingextractor import NpzSortingExtractor, read_npz_sorting
-from .numpyextractors import NumpyRecording, NumpySorting, NumpyEvent, NumpySnippets
+from .numpyextractors import NumpyRecording, NumpySorting, SharedMemmorySorting, NumpyEvent, NumpySnippets
 from .zarrrecordingextractor import ZarrRecordingExtractor, read_zarr, get_default_zarr_compressor
 from .binaryfolder import BinaryFolderRecording, read_binary_folder
 from .npzfolder import NpzFolderSorting, read_npz_folder
diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py
index 31183678f7..f112ebf2ba 100644
--- a/src/spikeinterface/core/basesorting.py
+++ b/src/spikeinterface/core/basesorting.py
@@ -463,6 +463,25 @@ def to_spike_vector(self, concatenated=True, extremum_channel_inds=None, use_cac

         return spikes

+    def to_numpy_sorting(self):
+        """
+        Turn any sorting into a NumpySorting.
+        Useful to have it in memory with a unique vector representation.
+        """
+        from .numpyextractors import NumpySorting
+        sorting = NumpySorting.from_sorting(self)
+        return sorting
+
+    def to_shared_memmory_sorting(self):
+        """
+        Turn any sorting into a SharedMemmorySorting.
+        Useful to have it in memory with a unique vector representation that is sharable across processes.
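+
+        A minimal sketch of the intended usage (illustrative):
+
+        >>> shm_sorting = sorting.to_shared_memmory_sorting()
+        >>> spikes = shm_sorting.to_spike_vector()  # reads the shared-memory backed spike vector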
+ """ + from .numpyextractors import SharedMemmorySorting + sorting = SharedMemmorySorting.from_sorting(self) + return sorting + + class BaseSortingSegment(BaseSegment): """ diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py index 5373746f49..9377373a4b 100644 --- a/src/spikeinterface/core/numpyextractors.py +++ b/src/spikeinterface/core/numpyextractors.py @@ -345,20 +345,47 @@ def get_unit_spike_train(self, unit_id, start_frame, end_frame): return times -# class SharedMemmorySorting(BaseSorting): -# def __init__(self, shm_name, shape, dtype=minimum_spike_dtype): +class SharedMemmorySorting(BaseSorting): + def __init__(self, shm_name, shape, sampling_frequency, unit_ids, dtype=minimum_spike_dtype): + assert len(shape) == 1 + assert shape[0] > 0, 'SharedMemmorySorting only supported with no empty sorting' -# shm = SharedMemory(shm_name) -# arr = np.ndarray(shape=shape, dtype=dtype, buffer=shm.buf) + BaseSorting.__init__(self, sampling_frequency, unit_ids) + self.is_dumpable = True + self.shm = SharedMemory(shm_name, create=False) + self.shm_spikes = np.ndarray(shape=shape, dtype=dtype, buffer=self.shm.buf) -# for segment_index in range(nseg): -# self.add_sorting_segment(NumpySortingSegment(spikes, segment_index, unit_ids)) + nseg = self.shm_spikes[-1]["segment_index"] + 1 + for segment_index in range(nseg): + self.add_sorting_segment(NumpySortingSegment(self.shm_spikes, segment_index, unit_ids)) -# @staticmethod -# def from_sorting(source_sorting: BaseSorting) -> "SharedMemmorySorting": + # important trick : the cache is already spikes vector + self._cached_spike_vector = self.shm_spikes + + self._kwargs = dict(shm_name=shm_name, shape=shape, + sampling_frequency=sampling_frequency, unit_ids=unit_ids) + + def __del__(self): + # this try to avoid + # "UserWarning: resource_tracker: There appear to be 1 leaked shared_memory objects to clean up at shutdown" + # But still nedd investigation because do not work + print('__del__') + self._segments = None + self.shm_spikes = None + self.shm.close() + self.shm = None + print('after __del__') -# make_shared_array(shape, dtype) + @staticmethod + def from_sorting(source_sorting): + spikes = source_sorting.to_spike_vector() + shm_spikes, shm = make_shared_array(spikes.shape, spikes.dtype) + shm_spikes[:] = spikes + sorting = SharedMemmorySorting(shm.name, spikes.shape, source_sorting.get_sampling_frequency(), + source_sorting.unit_ids, dtype=spikes.dtype) + shm.close() + return sorting class NumpyEvent(BaseEvent): diff --git a/src/spikeinterface/core/tests/test_numpy_extractors.py b/src/spikeinterface/core/tests/test_numpy_extractors.py index 9970b8b8b0..36e1be2d2a 100644 --- a/src/spikeinterface/core/tests/test_numpy_extractors.py +++ b/src/spikeinterface/core/tests/test_numpy_extractors.py @@ -4,8 +4,8 @@ import pytest import numpy as np -from spikeinterface.core import NumpyRecording, NumpySorting, NumpyEvent -from spikeinterface.core import create_sorting_npz +from spikeinterface.core import NumpyRecording, NumpySorting, SharedMemmorySorting, NumpyEvent +from spikeinterface.core import create_sorting_npz, load_extractor from spikeinterface.core import NpzSortingExtractor from spikeinterface.core.basesorting import minimum_spike_dtype @@ -62,6 +62,46 @@ def test_NumpySorting(): sorting = NumpySorting.from_sorting(other_sorting) # print(sorting) + # TODO test too_dict()/ + # TODO some test on caching + + + +def test_SharedMemmorySorting(): + sampling_frequency = 30000 + unit_ids = ['a', 'b', 'c'] + 
spikes = np.zeros(100, dtype=minimum_spike_dtype)
+    spikes["sample_index"][:] = np.arange(0, 1000, 10, dtype='int64')
+    spikes["unit_index"][0::3] = 0
+    spikes["unit_index"][1::3] = 1
+    spikes["unit_index"][2::3] = 2
+    np_sorting = NumpySorting(spikes, sampling_frequency, unit_ids)
+    print(np_sorting)
+
+    sorting = SharedMemmorySorting.from_sorting(np_sorting)
+    # print(sorting)
+    assert sorting._cached_spike_vector is not None
+
+    print(sorting.to_spike_vector())
+    d = sorting.to_dict()
+
+    sorting_reload = load_extractor(d)
+    # print(sorting_reload)
+    print(sorting_reload.to_spike_vector())
+
+    assert sorting.shm.name == sorting_reload.shm.name
+
+    # this try to avoid
+    # "UserWarning: resource_tracker: There appear to be 1 leaked shared_memory objects to clean up at shutdown"
+    # But still need investigation because do not work
+    del sorting_reload
+    del sorting
+
+
+
+
 def test_NumpyEvent():
     # one segment - dtype simple
     d = {
@@ -102,6 +142,7 @@


 if __name__ == "__main__":
-    test_NumpyRecording()
-    test_NumpySorting()
-    test_NumpyEvent()
+    # test_NumpyRecording()
+    # test_NumpySorting()
+    test_SharedMemmorySorting()
+    # test_NumpyEvent()

From fe7ec72fc99ee6a1f3da2b189c1c74fa17feac1c Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 6 Jun 2023 20:56:14 +0000
Subject: [PATCH 018/166] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/spikeinterface/core/basesorting.py        |  5 +++--
 src/spikeinterface/core/numpyextractors.py    | 16 ++++++++--------
 .../core/tests/test_numpy_extractors.py       | 18 ++++++------------
 3 files changed, 17 insertions(+), 22 deletions(-)

diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py
index c577329148..9beb36d958 100644
--- a/src/spikeinterface/core/basesorting.py
+++ b/src/spikeinterface/core/basesorting.py
@@ -468,20 +468,21 @@ def to_numpy_sorting(self):
         Useful to have it in memory with a unique vector representation.
         """
         from .numpyextractors import NumpySorting
+
         sorting = NumpySorting.from_sorting(self)
         return sorting
-    
+
     def to_shared_memmory_sorting(self):
         """
         Turn any sorting into a SharedMemmorySorting.
         Useful to have it in memory with a unique vector representation that is sharable across processes.
         """
         from .numpyextractors import SharedMemmorySorting
+
         sorting = SharedMemmorySorting.from_sorting(self)
         return sorting


 class BaseSortingSegment(BaseSegment):
     """
     Abstract class representing several units and their spike trains inside a segment.
diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py index 9377373a4b..df5bb8cb97 100644 --- a/src/spikeinterface/core/numpyextractors.py +++ b/src/spikeinterface/core/numpyextractors.py @@ -348,7 +348,7 @@ def get_unit_spike_train(self, unit_id, start_frame, end_frame): class SharedMemmorySorting(BaseSorting): def __init__(self, shm_name, shape, sampling_frequency, unit_ids, dtype=minimum_spike_dtype): assert len(shape) == 1 - assert shape[0] > 0, 'SharedMemmorySorting only supported with no empty sorting' + assert shape[0] > 0, "SharedMemmorySorting only supported with no empty sorting" BaseSorting.__init__(self, sampling_frequency, unit_ids) self.is_dumpable = True @@ -363,27 +363,27 @@ def __init__(self, shm_name, shape, sampling_frequency, unit_ids, dtype=minimum_ # important trick : the cache is already spikes vector self._cached_spike_vector = self.shm_spikes - self._kwargs = dict(shm_name=shm_name, shape=shape, - sampling_frequency=sampling_frequency, unit_ids=unit_ids) + self._kwargs = dict(shm_name=shm_name, shape=shape, sampling_frequency=sampling_frequency, unit_ids=unit_ids) def __del__(self): - # this try to avoid + # this try to avoid # "UserWarning: resource_tracker: There appear to be 1 leaked shared_memory objects to clean up at shutdown" # But still nedd investigation because do not work - print('__del__') + print("__del__") self._segments = None self.shm_spikes = None self.shm.close() self.shm = None - print('after __del__') + print("after __del__") @staticmethod def from_sorting(source_sorting): spikes = source_sorting.to_spike_vector() shm_spikes, shm = make_shared_array(spikes.shape, spikes.dtype) shm_spikes[:] = spikes - sorting = SharedMemmorySorting(shm.name, spikes.shape, source_sorting.get_sampling_frequency(), - source_sorting.unit_ids, dtype=spikes.dtype) + sorting = SharedMemmorySorting( + shm.name, spikes.shape, source_sorting.get_sampling_frequency(), source_sorting.unit_ids, dtype=spikes.dtype + ) shm.close() return sorting diff --git a/src/spikeinterface/core/tests/test_numpy_extractors.py b/src/spikeinterface/core/tests/test_numpy_extractors.py index 36e1be2d2a..56c0b0a67b 100644 --- a/src/spikeinterface/core/tests/test_numpy_extractors.py +++ b/src/spikeinterface/core/tests/test_numpy_extractors.py @@ -66,12 +66,11 @@ def test_NumpySorting(): # TODO some test on caching - def test_SharedMemmorySorting(): sampling_frequency = 30000 - unit_ids = ['a', 'b', 'c'] + unit_ids = ["a", "b", "c"] spikes = np.zeros(100, dtype=minimum_spike_dtype) - spikes["sample_index"][:] = np.arange(0, 1000, 10, dtype='int64') + spikes["sample_index"][:] = np.arange(0, 1000, 10, dtype="int64") spikes["unit_index"][0::3] = 0 spikes["unit_index"][1::3] = 1 spikes["unit_index"][2::3] = 2 @@ -79,30 +78,25 @@ def test_SharedMemmorySorting(): print(np_sorting) sorting = SharedMemmorySorting.from_sorting(np_sorting) - # print(sorting) + # print(sorting) assert sorting._cached_spike_vector is not None print(sorting.to_spike_vector()) d = sorting.to_dict() sorting_reload = load_extractor(d) - # print(sorting_reload) + # print(sorting_reload) print(sorting_reload.to_spike_vector()) assert sorting.shm.name == sorting_reload.shm.name - - # this try to avoid + + # this try to avoid # "UserWarning: resource_tracker: There appear to be 1 leaked shared_memory objects to clean up at shutdown" # But still need investigation because do not work del sorting_reload del sorting - - - - - def test_NumpyEvent(): # one segment - dtype simple d = { From 
e364257ea1222b0fe963a62d0902757e2c360aa8 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 6 Jun 2023 23:11:38 +0200 Subject: [PATCH 019/166] Fix the "UserWarning: resource_tracker: There appear to be 1 leaked shared_memory objects to clean up at shutdown" --- src/spikeinterface/core/basesorting.py | 4 +++- src/spikeinterface/core/numpyextractors.py | 21 ++++++++++--------- .../core/tests/test_numpy_extractors.py | 12 ++--------- 3 files changed, 16 insertions(+), 21 deletions(-) diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index c577329148..c27067f700 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -333,7 +333,9 @@ def frame_slice(self, start_frame, end_frame): def get_all_spike_trains(self, outputs="unit_id"): """ - Return all spike trains concatenated + Return all spike trains concatenated. + + This is deprecated use sorting.to_spike_vector() instead """ warnings.warn( diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py index 9377373a4b..4d89f6b210 100644 --- a/src/spikeinterface/core/numpyextractors.py +++ b/src/spikeinterface/core/numpyextractors.py @@ -346,7 +346,8 @@ def get_unit_spike_train(self, unit_id, start_frame, end_frame): class SharedMemmorySorting(BaseSorting): - def __init__(self, shm_name, shape, sampling_frequency, unit_ids, dtype=minimum_spike_dtype): + def __init__(self, shm_name, shape, sampling_frequency, unit_ids, + dtype=minimum_spike_dtype, main_shm_owner=True): assert len(shape) == 1 assert shape[0] > 0, 'SharedMemmorySorting only supported with no empty sorting' @@ -363,19 +364,19 @@ def __init__(self, shm_name, shape, sampling_frequency, unit_ids, dtype=minimum_ # important trick : the cache is already spikes vector self._cached_spike_vector = self.shm_spikes + # this is very important for the shm.unlink() + # only the main instance need to call it + # all other instances that are loaded from dict are not the main owner + self.main_shm_owner = main_shm_owner + self._kwargs = dict(shm_name=shm_name, shape=shape, - sampling_frequency=sampling_frequency, unit_ids=unit_ids) + sampling_frequency=sampling_frequency, unit_ids=unit_ids, + main_shm_owner=False) def __del__(self): - # this try to avoid - # "UserWarning: resource_tracker: There appear to be 1 leaked shared_memory objects to clean up at shutdown" - # But still nedd investigation because do not work - print('__del__') - self._segments = None - self.shm_spikes = None self.shm.close() - self.shm = None - print('after __del__') + if self.main_shm_owner: + self.shm.unlink() @staticmethod def from_sorting(source_sorting): diff --git a/src/spikeinterface/core/tests/test_numpy_extractors.py b/src/spikeinterface/core/tests/test_numpy_extractors.py index 36e1be2d2a..d712cf8bc3 100644 --- a/src/spikeinterface/core/tests/test_numpy_extractors.py +++ b/src/spikeinterface/core/tests/test_numpy_extractors.py @@ -82,23 +82,15 @@ def test_SharedMemmorySorting(): # print(sorting) assert sorting._cached_spike_vector is not None - print(sorting.to_spike_vector()) + # print(sorting.to_spike_vector()) d = sorting.to_dict() sorting_reload = load_extractor(d) # print(sorting_reload) - print(sorting_reload.to_spike_vector()) + # print(sorting_reload.to_spike_vector()) assert sorting.shm.name == sorting_reload.shm.name - # this try to avoid - # "UserWarning: resource_tracker: There appear to be 1 leaked shared_memory objects to clean up at shutdown" - # But still need 
investigation because do not work - del sorting_reload - del sorting - - - From d35ba802babc080f02d213f0d916bb0f2705de0a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 6 Jun 2023 21:15:31 +0000 Subject: [PATCH 020/166] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/numpyextractors.py | 14 ++++++++------ .../core/tests/test_numpy_extractors.py | 13 ++++--------- 2 files changed, 12 insertions(+), 15 deletions(-) diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py index 547feebde7..19d1077288 100644 --- a/src/spikeinterface/core/numpyextractors.py +++ b/src/spikeinterface/core/numpyextractors.py @@ -346,8 +346,7 @@ def get_unit_spike_train(self, unit_id, start_frame, end_frame): class SharedMemmorySorting(BaseSorting): - def __init__(self, shm_name, shape, sampling_frequency, unit_ids, - dtype=minimum_spike_dtype, main_shm_owner=True): + def __init__(self, shm_name, shape, sampling_frequency, unit_ids, dtype=minimum_spike_dtype, main_shm_owner=True): assert len(shape) == 1 assert shape[0] > 0, "SharedMemmorySorting only supported with no empty sorting" @@ -364,15 +363,18 @@ def __init__(self, shm_name, shape, sampling_frequency, unit_ids, # important trick : the cache is already spikes vector self._cached_spike_vector = self.shm_spikes - # this is very important for the shm.unlink() # only the main instance need to call it # all other instances that are loaded from dict are not the main owner self.main_shm_owner = main_shm_owner - self._kwargs = dict(shm_name=shm_name, shape=shape, - sampling_frequency=sampling_frequency, unit_ids=unit_ids, - main_shm_owner=False) + self._kwargs = dict( + shm_name=shm_name, + shape=shape, + sampling_frequency=sampling_frequency, + unit_ids=unit_ids, + main_shm_owner=False, + ) def __del__(self): self.shm.close() diff --git a/src/spikeinterface/core/tests/test_numpy_extractors.py b/src/spikeinterface/core/tests/test_numpy_extractors.py index 7569a510a3..3abd6e108b 100644 --- a/src/spikeinterface/core/tests/test_numpy_extractors.py +++ b/src/spikeinterface/core/tests/test_numpy_extractors.py @@ -66,12 +66,11 @@ def test_NumpySorting(): # TODO some test on caching - def test_SharedMemmorySorting(): sampling_frequency = 30000 - unit_ids = ['a', 'b', 'c'] + unit_ids = ["a", "b", "c"] spikes = np.zeros(100, dtype=minimum_spike_dtype) - spikes["sample_index"][:] = np.arange(0, 1000, 10, dtype='int64') + spikes["sample_index"][:] = np.arange(0, 1000, 10, dtype="int64") spikes["unit_index"][0::3] = 0 spikes["unit_index"][1::3] = 1 spikes["unit_index"][2::3] = 2 @@ -79,21 +78,17 @@ def test_SharedMemmorySorting(): print(np_sorting) sorting = SharedMemmorySorting.from_sorting(np_sorting) - # print(sorting) + # print(sorting) assert sorting._cached_spike_vector is not None # print(sorting.to_spike_vector()) d = sorting.to_dict() sorting_reload = load_extractor(d) - # print(sorting_reload) + # print(sorting_reload) # print(sorting_reload.to_spike_vector()) assert sorting.shm.name == sorting_reload.shm.name - - - - def test_NumpyEvent(): From 20af0090091a142312bbb15a5e3ba1c291e6e115 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 7 Jun 2023 09:52:21 +0200 Subject: [PATCH 021/166] Use sorting.to_multiprocessing() when using ChunkRecordingExecutor --- src/spikeinterface/core/basesorting.py | 30 +++++++++++++++++++ .../postprocessing/principal_component.py | 28 
+++++++++++------ .../postprocessing/spike_amplitudes.py | 2 +- 3 files changed, 50 insertions(+), 10 deletions(-) diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index 2506d12562..b2ad631b30 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -483,6 +483,36 @@ def to_shared_memmory_sorting(self): sorting = SharedMemmorySorting.from_sorting(self) return sorting + + def to_multiprocessing(self, n_jobs): + """ + When necessary turn sorting object into: + * NumpySorting + * SharedMemmorySorting + * TODO add new format + + Parameters + ---------- + n_jobs: int + The number of jobs. + Returns + ------- + sharable_sorting: + A sorting that can be + + """ + from .numpyextractors import NumpySorting, SharedMemmorySorting + if n_jobs == 1: + if isinstance(self, (NumpySorting, SharedMemmorySorting)): + return self + else: + return NumpySorting.from_sorting(self) + else: + if isinstance(self, SharedMemmorySorting): + return self + else: + return SharedMemmorySorting.from_sorting(self) + class BaseSortingSegment(BaseSegment): diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py index 9465b16db6..3eb6570875 100644 --- a/src/spikeinterface/postprocessing/principal_component.py +++ b/src/spikeinterface/postprocessing/principal_component.py @@ -307,11 +307,11 @@ def run_for_all_spikes(self, file_path=None, **job_kwargs): file_path = self.extension_folder / "all_pcs.npy" file_path = Path(file_path) - spikes = sorting.to_spike_vector(concatenated=False) - # This is the first segment only - spikes = spikes[0] - spike_times = spikes["sample_index"] - spike_labels = spikes["unit_index"] + # spikes = sorting.to_spike_vector(concatenated=False) + # # This is the first segment only + # spikes = spikes[0] + # spike_times = spikes["sample_index"] + # spike_labels = spikes["unit_index"] sparsity = self.get_sparsity() if sparsity is None: @@ -330,7 +330,8 @@ def run_for_all_spikes(self, file_path=None, **job_kwargs): # nSpikes, nFeaturesPerChannel, nPCFeatures # this comes from phy template-gui # https://github.com/kwikteam/phy-contrib/blob/master/docs/template-gui.md#datasets - shape = (spike_times.size, p["n_components"], max_channels_per_template) + num_spikes = sorting.to_spike_vector().size + shape = (num_spikes, p["n_components"], max_channels_per_template) all_pcs = np.lib.format.open_memmap(filename=file_path, mode="w+", dtype="float32", shape=shape) all_pcs_args = dict(filename=file_path, mode="r+", dtype="float32", shape=shape) @@ -339,9 +340,8 @@ def run_for_all_spikes(self, file_path=None, **job_kwargs): init_func = _init_work_all_pc_extractor init_args = ( recording, + sorting.to_multiprocessing(job_kwargs['n_jobs']), all_pcs_args, - spike_times, - spike_labels, we.nbefore, we.nafter, unit_channels, @@ -631,14 +631,24 @@ def _all_pc_extractor_chunk(segment_index, start_frame, end_frame, worker_ctx): def _init_work_all_pc_extractor( - recording, all_pcs_args, spike_times, spike_labels, nbefore, nafter, unit_channels, pca_model + recording, sorting, all_pcs_args, nbefore, nafter, unit_channels, pca_model ): + worker_ctx = {} if isinstance(recording, dict): from spikeinterface.core import load_extractor recording = load_extractor(recording) worker_ctx["recording"] = recording + worker_ctx["sorting"] = sorting + + spikes = sorting.to_spike_vector(concatenated=False) + # This is the first segment only + spikes = spikes[0] + spike_times = 
spikes["sample_index"] + spike_labels = spikes["unit_index"] + + worker_ctx["all_pcs"] = np.lib.format.open_memmap(**all_pcs_args) worker_ctx["spike_times"] = spike_times worker_ctx["spike_labels"] = spike_labels diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py index 6790e6d113..9eaeac5fc7 100644 --- a/src/spikeinterface/postprocessing/spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/spike_amplitudes.py @@ -79,7 +79,7 @@ def _run(self, **job_kwargs): "The soring object is not dumpable and cannot be processed in parallel. You can use the " "`sorting.save()` function to make it dumpable" ) - init_args = (recording, sorting, extremum_channels_index, peak_shifts, return_scaled) + init_args = (recording, sorting.to_multiprocessing(n_jobs), extremum_channels_index, peak_shifts, return_scaled) processor = ChunkRecordingExecutor( recording, func, init_func, init_args, handle_returns=True, job_name="extract amplitudes", **job_kwargs ) From bec72d7dada99d25cb9ea7dc8ca356bf1fda41c9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 7 Jun 2023 07:52:57 +0000 Subject: [PATCH 022/166] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/basesorting.py | 12 ++++++------ .../postprocessing/principal_component.py | 8 ++------ 2 files changed, 8 insertions(+), 12 deletions(-) diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index b2ad631b30..5c85addc3b 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -483,7 +483,7 @@ def to_shared_memmory_sorting(self): sorting = SharedMemmorySorting.from_sorting(self) return sorting - + def to_multiprocessing(self, n_jobs): """ When necessary turn sorting object into: @@ -496,12 +496,13 @@ def to_multiprocessing(self, n_jobs): n_jobs: int The number of jobs. Returns - ------- - sharable_sorting: - A sorting that can be - + ------- + sharable_sorting: + A sorting that can be + """ from .numpyextractors import NumpySorting, SharedMemmorySorting + if n_jobs == 1: if isinstance(self, (NumpySorting, SharedMemmorySorting)): return self @@ -514,7 +515,6 @@ def to_multiprocessing(self, n_jobs): return SharedMemmorySorting.from_sorting(self) - class BaseSortingSegment(BaseSegment): """ Abstract class representing several units and relative spiketrain inside a segment. 
diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py index 3eb6570875..c5793c9e1b 100644 --- a/src/spikeinterface/postprocessing/principal_component.py +++ b/src/spikeinterface/postprocessing/principal_component.py @@ -340,7 +340,7 @@ def run_for_all_spikes(self, file_path=None, **job_kwargs): init_func = _init_work_all_pc_extractor init_args = ( recording, - sorting.to_multiprocessing(job_kwargs['n_jobs']), + sorting.to_multiprocessing(job_kwargs["n_jobs"]), all_pcs_args, we.nbefore, we.nafter, @@ -630,10 +630,7 @@ def _all_pc_extractor_chunk(segment_index, start_frame, end_frame, worker_ctx): all_pcs[i, :, c] = pca_model[chan_ind].transform(w) -def _init_work_all_pc_extractor( - recording, sorting, all_pcs_args, nbefore, nafter, unit_channels, pca_model -): - +def _init_work_all_pc_extractor(recording, sorting, all_pcs_args, nbefore, nafter, unit_channels, pca_model): worker_ctx = {} if isinstance(recording, dict): from spikeinterface.core import load_extractor @@ -647,7 +644,6 @@ def _init_work_all_pc_extractor( spikes = spikes[0] spike_times = spikes["sample_index"] spike_labels = spikes["unit_index"] - worker_ctx["all_pcs"] = np.lib.format.open_memmap(**all_pcs_args) worker_ctx["spike_times"] = spike_times From 60ed688a179fe427e9578312b95ea45007a79771 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 7 Jun 2023 10:40:57 +0200 Subject: [PATCH 023/166] typos --- src/spikeinterface/core/__init__.py | 2 +- src/spikeinterface/core/basesorting.py | 18 +++++++++--------- src/spikeinterface/core/numpyextractors.py | 6 +++--- .../core/tests/test_numpy_extractors.py | 8 ++++---- .../core/tests/test_waveform_tools.py | 2 +- 5 files changed, 18 insertions(+), 18 deletions(-) diff --git a/src/spikeinterface/core/__init__.py b/src/spikeinterface/core/__init__.py index d85aec5787..cec747e070 100644 --- a/src/spikeinterface/core/__init__.py +++ b/src/spikeinterface/core/__init__.py @@ -8,7 +8,7 @@ # main extractor from dump and cache from .binaryrecordingextractor import BinaryRecordingExtractor, read_binary from .npzsortingextractor import NpzSortingExtractor, read_npz_sorting -from .numpyextractors import NumpyRecording, NumpySorting, SharedMemmorySorting, NumpyEvent, NumpySnippets +from .numpyextractors import NumpyRecording, NumpySorting, SharedMemorySorting, NumpyEvent, NumpySnippets from .zarrrecordingextractor import ZarrRecordingExtractor, read_zarr, get_default_zarr_compressor from .binaryfolder import BinaryFolderRecording, read_binary_folder from .npzfolder import NpzFolderSorting, read_npz_folder diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index b2ad631b30..dd010fe1ee 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -474,21 +474,21 @@ def to_numpy_sorting(self): sorting = NumpySorting.from_sorting(self) return sorting - def to_shared_memmory_sorting(self): + def to_shared_memory_sorting(self): """ - Turn any sorting in a SharedMemmorySorting. + Turn any sorting in a SharedMemorySorting. Usefull to have it in memory with a unique vector representation and sharable acros processes. 
""" - from .numpyextractors import SharedMemmorySorting + from .numpyextractors import SharedMemorySorting - sorting = SharedMemmorySorting.from_sorting(self) + sorting = SharedMemorySorting.from_sorting(self) return sorting def to_multiprocessing(self, n_jobs): """ When necessary turn sorting object into: * NumpySorting - * SharedMemmorySorting + * SharedMemorySorting * TODO add new format Parameters @@ -501,17 +501,17 @@ def to_multiprocessing(self, n_jobs): A sorting that can be """ - from .numpyextractors import NumpySorting, SharedMemmorySorting + from .numpyextractors import NumpySorting, SharedMemorySorting if n_jobs == 1: - if isinstance(self, (NumpySorting, SharedMemmorySorting)): + if isinstance(self, (NumpySorting, SharedMemorySorting)): return self else: return NumpySorting.from_sorting(self) else: - if isinstance(self, SharedMemmorySorting): + if isinstance(self, SharedMemorySorting): return self else: - return SharedMemmorySorting.from_sorting(self) + return SharedMemorySorting.from_sorting(self) diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py index 19d1077288..42327e5a5c 100644 --- a/src/spikeinterface/core/numpyextractors.py +++ b/src/spikeinterface/core/numpyextractors.py @@ -345,10 +345,10 @@ def get_unit_spike_train(self, unit_id, start_frame, end_frame): return times -class SharedMemmorySorting(BaseSorting): +class SharedMemorySorting(BaseSorting): def __init__(self, shm_name, shape, sampling_frequency, unit_ids, dtype=minimum_spike_dtype, main_shm_owner=True): assert len(shape) == 1 - assert shape[0] > 0, "SharedMemmorySorting only supported with no empty sorting" + assert shape[0] > 0, "SharedMemorySorting only supported with no empty sorting" BaseSorting.__init__(self, sampling_frequency, unit_ids) self.is_dumpable = True @@ -386,7 +386,7 @@ def from_sorting(source_sorting): spikes = source_sorting.to_spike_vector() shm_spikes, shm = make_shared_array(spikes.shape, spikes.dtype) shm_spikes[:] = spikes - sorting = SharedMemmorySorting( + sorting = SharedMemorySorting( shm.name, spikes.shape, source_sorting.get_sampling_frequency(), source_sorting.unit_ids, dtype=spikes.dtype ) shm.close() diff --git a/src/spikeinterface/core/tests/test_numpy_extractors.py b/src/spikeinterface/core/tests/test_numpy_extractors.py index 3abd6e108b..36a7585e7c 100644 --- a/src/spikeinterface/core/tests/test_numpy_extractors.py +++ b/src/spikeinterface/core/tests/test_numpy_extractors.py @@ -4,7 +4,7 @@ import pytest import numpy as np -from spikeinterface.core import NumpyRecording, NumpySorting, SharedMemmorySorting, NumpyEvent +from spikeinterface.core import NumpyRecording, NumpySorting, SharedMemorySorting, NumpyEvent from spikeinterface.core import create_sorting_npz, load_extractor from spikeinterface.core import NpzSortingExtractor from spikeinterface.core.basesorting import minimum_spike_dtype @@ -66,7 +66,7 @@ def test_NumpySorting(): # TODO some test on caching -def test_SharedMemmorySorting(): +def test_SharedMemorySorting(): sampling_frequency = 30000 unit_ids = ["a", "b", "c"] spikes = np.zeros(100, dtype=minimum_spike_dtype) @@ -77,7 +77,7 @@ def test_SharedMemmorySorting(): np_sorting = NumpySorting(spikes, sampling_frequency, unit_ids) print(np_sorting) - sorting = SharedMemmorySorting.from_sorting(np_sorting) + sorting = SharedMemorySorting.from_sorting(np_sorting) # print(sorting) assert sorting._cached_spike_vector is not None @@ -132,5 +132,5 @@ def test_NumpyEvent(): if __name__ == "__main__": # 
test_NumpyRecording() # test_NumpySorting() - test_SharedMemmorySorting() + test_SharedMemorySorting() # test_NumpyEvent() diff --git a/src/spikeinterface/core/tests/test_waveform_tools.py b/src/spikeinterface/core/tests/test_waveform_tools.py index a896ff9c8b..8da47b1940 100644 --- a/src/spikeinterface/core/tests/test_waveform_tools.py +++ b/src/spikeinterface/core/tests/test_waveform_tools.py @@ -63,7 +63,7 @@ def test_waveform_tools(): wf_folder = cache_folder / f"test_waveform_tools_{j}" if wf_folder.is_dir(): shutil.rmtree(wf_folder) - wf_folder.mkdir() + wf_folder.mkdir(parents=True) # wfs_arrays, wfs_arrays_info = allocate_waveforms_buffers(recording, spikes, unit_ids, nbefore, nafter, mode='memmap', folder=wf_folder, dtype=dtype) # distribute_waveforms_to_buffers(recording, spikes, unit_ids, wfs_arrays_info, nbefore, nafter, return_scaled, **job_kwargs) wfs_arrays = extract_waveforms_to_buffers( From 0b251f6368b5b0e4f9399fa1047497e1bafcae11 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 7 Jun 2023 08:42:02 +0000 Subject: [PATCH 024/166] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/basesorting.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index 8997b97672..675235053a 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -502,6 +502,7 @@ def to_multiprocessing(self, n_jobs): """ from .numpyextractors import NumpySorting, SharedMemorySorting + if n_jobs == 1: if isinstance(self, (NumpySorting, SharedMemorySorting)): return self From 411afb757ce679568b8049801a0210e07b7b4309 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 7 Jun 2023 14:11:31 +0200 Subject: [PATCH 025/166] Implement the NumpyFolderSorting as by default format. 
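For reference, the folder written by the new format can be read back by hand. A sketch, where "/myfolder" is a placeholder path and the file and key names follow the implementation in the diff below:

import json
from pathlib import Path
import numpy as np

folder = Path("/myfolder")
info = json.loads((folder / "numpysorting_info.json").read_text())
# the flattened spike vector written from sorting.to_spike_vector()
spikes = np.load(folder / "spikes.npy", mmap_mode="r")

print(info["sampling_frequency"], info["unit_ids"], info["num_segments"])
print(spikes["sample_index"][:10], spikes["unit_index"][:10])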
--- src/spikeinterface/core/__init__.py | 2 +- src/spikeinterface/core/basesorting.py | 34 +++-- src/spikeinterface/core/npzfolder.py | 55 -------- src/spikeinterface/core/numpyextractors.py | 2 +- src/spikeinterface/core/sortingfolder.py | 127 ++++++++++++++++++ src/spikeinterface/core/testing.py | 1 + .../core/tests/test_basesorting.py | 17 ++- .../core/tests/test_sorting_folder.py | 51 +++++++ .../core/tests/test_template_tools.py | 8 +- 9 files changed, 221 insertions(+), 76 deletions(-) delete mode 100644 src/spikeinterface/core/npzfolder.py create mode 100644 src/spikeinterface/core/sortingfolder.py create mode 100644 src/spikeinterface/core/tests/test_sorting_folder.py diff --git a/src/spikeinterface/core/__init__.py b/src/spikeinterface/core/__init__.py index cec747e070..0e93eb5877 100644 --- a/src/spikeinterface/core/__init__.py +++ b/src/spikeinterface/core/__init__.py @@ -11,7 +11,7 @@ from .numpyextractors import NumpyRecording, NumpySorting, SharedMemorySorting, NumpyEvent, NumpySnippets from .zarrrecordingextractor import ZarrRecordingExtractor, read_zarr, get_default_zarr_compressor from .binaryfolder import BinaryFolderRecording, read_binary_folder -from .npzfolder import NpzFolderSorting, read_npz_folder +from .sortingfolder import NumpyFolderSorting, NpzFolderSorting, read_numpy_sorting_folder_folder, read_npz_folder from .npysnippetsextractor import NpySnippetsExtractor, read_npy_snippets from .npyfoldersnippets import NpyFolderSnippets, read_npy_snippets_folder diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index 8997b97672..0d943ea1a3 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -215,27 +215,37 @@ def get_times(self, segment_index=None): else: return None - def _save(self, format="npz", **save_kwargs): + def _save(self, format="numpy_folder", **save_kwargs): """ This function replaces the old CachesortingExtractor, but enables more engines - for caching a results. At the moment only 'npz' is supported. - """ - if format == "npz": - folder = save_kwargs.pop("folder") - # TODO save properties/features as npz!!!!! - from .npzsortingextractor import NpzSortingExtractor + for caching a results. + + Since v0.98.0 'numpy_folder' is used by defult. + From v0.96.0 to 0.97.0 'npz_folder' was the default. - save_path = folder / "sorting_cached.npz" - NpzSortingExtractor.write_sorting(self, save_path) - cached = NpzSortingExtractor(save_path) - cached.dump(folder / "npz.json", relative_to=folder) - from .npzfolder import NpzFolderSorting + At the moment only 'npz' is supported. 
+ """ + if format == "numpy_folder": + from .sortingfolder import NumpyFolderSorting + folder = save_kwargs.pop("folder") + NumpyFolderSorting.write_sorting(self, folder) + cached = NumpyFolderSorting(folder) + if self.has_recording(): + warnings.warn("The registered recording will not be persistent on disk, but only available in memory") + cached.register_recording(self._recording) + + elif format == "npz_folder": + from .sortingfolder import NpzFolderSorting + folder = save_kwargs.pop("folder") + NpzFolderSorting.write_sorting(self, folder) cached = NpzFolderSorting(folder_path=folder) + if self.has_recording(): warnings.warn("The registered recording will not be persistent on disk, but only available in memory") cached.register_recording(self._recording) + elif format == "memory": from .numpyextractors import NumpySorting diff --git a/src/spikeinterface/core/npzfolder.py b/src/spikeinterface/core/npzfolder.py deleted file mode 100644 index cd49ab472f..0000000000 --- a/src/spikeinterface/core/npzfolder.py +++ /dev/null @@ -1,55 +0,0 @@ -from pathlib import Path -import json - -import numpy as np - -from .base import _make_paths_absolute -from .npzsortingextractor import NpzSortingExtractor -from .core_tools import define_function_from_class - - -class NpzFolderSorting(NpzSortingExtractor): - """ - NpzFolderSorting is an internal format used in spikeinterface. - It is a NpzSortingExtractor + metadata contained in a folder. - - It is created with the function: `sorting.save(folder='/myfolder')` - - Parameters - ---------- - folder_path: str or Path - - Returns - ------- - sorting: NpzFolderSorting - The sorting - """ - - extractor_name = "NpzFolder" - has_default_locations = True - mode = "folder" - name = "npzfolder" - - def __init__(self, folder_path): - folder_path = Path(folder_path) - - with open(folder_path / "npz.json", "r") as f: - d = json.load(f) - - if not d["class"].endswith(".NpzSortingExtractor"): - raise ValueError("This folder is not an npz spikeinterface folder") - - assert d["relative_paths"] - - d = _make_paths_absolute(d, folder_path) - - NpzSortingExtractor.__init__(self, **d["kwargs"]) - - folder_metadata = folder_path - self.load_metadata_from_folder(folder_metadata) - - self._kwargs = dict(folder_path=str(folder_path.absolute())) - self._npz_kwargs = d["kwargs"] - - -read_npz_folder = define_function_from_class(source_class=NpzFolderSorting, name="read_npz_folder") diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py index 42327e5a5c..e349d5b06b 100644 --- a/src/spikeinterface/core/numpyextractors.py +++ b/src/spikeinterface/core/numpyextractors.py @@ -122,7 +122,7 @@ def __init__(self, spikes, sampling_frequency, unit_ids): """ """ BaseSorting.__init__(self, sampling_frequency, unit_ids) - self.is_dumpable = True + self.is_dumpable = False if spikes.size == 0: nseg = 1 diff --git a/src/spikeinterface/core/sortingfolder.py b/src/spikeinterface/core/sortingfolder.py new file mode 100644 index 0000000000..24de52ec49 --- /dev/null +++ b/src/spikeinterface/core/sortingfolder.py @@ -0,0 +1,127 @@ +from pathlib import Path +import json + +import numpy as np + +from .base import _make_paths_absolute +from .basesorting import BaseSorting, BaseSortingSegment +from .npzsortingextractor import NpzSortingExtractor +from .core_tools import define_function_from_class +from .numpyextractors import NumpySortingSegment + + + +class NumpyFolderSorting(BaseSorting): + """ + NumpyFolderSorting is the new internal format used in 
spikeinterface (>=0.98.0) + + It is a simple folder that contains all flatten spikes (using sorting.to_spike_vector() in a npy format. + + It is created with the function: `sorting.save(folder='/myfolder', format="numpy_folder")` + + """ + extractor_name = "NumpyFolderSorting" + mode = "folder" + name = "NumpyFolder" + + def __init__(self, folder_path): + folder_path = Path(folder_path) + + with open(folder_path / "numpysorting_info.json", "r") as f: + d = json.load(f) + + sampling_frequency = d['sampling_frequency'] + unit_ids = np.array(d['unit_ids']) + num_segments = d['num_segments'] + + BaseSorting.__init__(self, sampling_frequency, unit_ids) + + self.spikes = np.load(folder_path / 'spikes.npy', mmap_mode='r') + + for segment_index in range(num_segments): + self.add_sorting_segment(NumpySortingSegment(self.spikes, segment_index, unit_ids)) + + # important trick : the cache is already spikes vector + self._cached_spike_vector = self.spikes + + folder_metadata = folder_path + self.load_metadata_from_folder(folder_metadata) + + self._kwargs = dict(folder_path=folder_path.absolute()) + + @staticmethod + def write_sorting(sorting, save_path): + # the folder can already exists but not contaning numpysorting_info.json + save_path = Path(save_path) + save_path.mkdir(parents=True, exist_ok=True) + + info_file = save_path / "numpysorting_info.json" + if info_file.exists(): + raise ValueError("NumpyFolderSorting.write_sorting the folder already contains numpysorting_info.json") + d = { + 'sampling_frequency': float(sorting.get_sampling_frequency()), + 'unit_ids': sorting.unit_ids.tolist(), + 'num_segments': sorting.get_num_segments(), + } + info_file.write_text(json.dumps(d), encoding="utf8") + np.save(save_path / 'spikes.npy', sorting.to_spike_vector()) + + +class NpzFolderSorting(NpzSortingExtractor): + """ + NpzFolderSorting is the old internal format used in spikeinterface (<=0.97.0) + It is a NpzSortingExtractor + metadata contained in a folder. 
+ + It is created with the function: `sorting.save(folder='/myfolder', format="npz")` + + Parameters + ---------- + folder_path: str or Path + + Returns + ------- + sorting: NpzFolderSorting + The sorting + """ + + extractor_name = "NpzFolder" + mode = "folder" + name = "npzfolder" + + def __init__(self, folder_path): + folder_path = Path(folder_path) + + with open(folder_path / "npz.json", "r") as f: + d = json.load(f) + + if not d["class"].endswith(".NpzSortingExtractor"): + raise ValueError("This folder is not an npz spikeinterface folder") + + assert d["relative_paths"] + + d = _make_paths_absolute(d, folder_path) + + NpzSortingExtractor.__init__(self, **d["kwargs"]) + + folder_metadata = folder_path + self.load_metadata_from_folder(folder_metadata) + + self._kwargs = dict(folder_path=str(folder_path.absolute())) + self._npz_kwargs = d["kwargs"] + + @staticmethod + def write_sorting(sorting, save_path): + # the folder can already exists but not contaning numpysorting_info.json + save_path = Path(save_path) + save_path.mkdir(parents=True, exist_ok=True) + + npz_file = save_path / "sorting_cached.npz" + NpzSortingExtractor.write_sorting(sorting, npz_file) + cached = NpzSortingExtractor(npz_file) + cached.dump(save_path / "npz.json", relative_to=save_path) + + +read_numpy_sorting_folder_folder = define_function_from_class( + source_class=NumpyFolderSorting, name="read_numpy_sorting_folder_folder" +) +read_npz_folder = define_function_from_class(source_class=NpzFolderSorting, name="read_npz_folder") diff --git a/src/spikeinterface/core/testing.py b/src/spikeinterface/core/testing.py index e337d0b035..1a13974b51 100644 --- a/src/spikeinterface/core/testing.py +++ b/src/spikeinterface/core/testing.py @@ -78,6 +78,7 @@ def check_sortings_equal( ) -> None: assert SX1.get_num_segments() == SX2.get_num_segments() + # TODO for later use to_spike_vector() to do this without looping for segment_idx in range(SX1.get_num_segments()): # get_unit_ids ids1 = np.sort(np.array(SX1.get_unit_ids())) diff --git a/src/spikeinterface/core/tests/test_basesorting.py b/src/spikeinterface/core/tests/test_basesorting.py index f5307d3a28..141a50d3f1 100644 --- a/src/spikeinterface/core/tests/test_basesorting.py +++ b/src/spikeinterface/core/tests/test_basesorting.py @@ -13,6 +13,8 @@ NpzSortingExtractor, NumpyRecording, NumpySorting, + NpzFolderSorting, + NumpyFolderSorting, create_sorting_npz, generate_sorting, load_extractor, @@ -67,11 +69,20 @@ def test_BaseSorting(): check_sortings_equal(sorting, sorting2, check_annotations=True, check_properties=True) check_sortings_equal(sorting, sorting3, check_annotations=True, check_properties=True) - # cache - folder = cache_folder / "simple_sorting" + # cache old format : npz_folder + folder = cache_folder / "simple_sorting_npz_folder" sorting.set_property("test", np.ones(len(sorting.unit_ids))) - sorting.save(folder=folder) + sorting.save(folder=folder, format='npz_folder') sorting2 = BaseExtractor.load_from_folder(folder) + assert isinstance(sorting2, NpzFolderSorting) + + # cache new format : numpy_folder + folder = cache_folder / "simple_sorting_numpy_folder" + sorting.set_property("test", np.ones(len(sorting.unit_ids))) + sorting.save(folder=folder, format='numpy_folder') + sorting2 = BaseExtractor.load_from_folder(folder) + assert isinstance(sorting2, NumpyFolderSorting) + # but also possible sorting3 = BaseExtractor.load(folder) check_sortings_equal(sorting, sorting2, check_annotations=True, check_properties=True) diff --git 
a/src/spikeinterface/core/tests/test_sorting_folder.py b/src/spikeinterface/core/tests/test_sorting_folder.py new file mode 100644 index 0000000000..b1329765ad --- /dev/null +++ b/src/spikeinterface/core/tests/test_sorting_folder.py @@ -0,0 +1,51 @@ +import pytest + +from pathlib import Path +import shutil + +import numpy as np + +from spikeinterface.core import NpzFolderSorting, NumpyFolderSorting, load_extractor +from spikeinterface.core import generate_sorting +from spikeinterface.core.testing import check_sorted_arrays_equal, check_sortings_equal + +if hasattr(pytest, "global_test_folder"): + cache_folder = pytest.global_test_folder / "core" +else: + cache_folder = Path("cache_folder") / "core" + + +def test_NumpyFolderSorting(): + sorting = generate_sorting() + + folder = cache_folder / "numpy_sorting_1" + if folder.is_dir(): + shutil.rmtree(folder) + + NumpyFolderSorting.write_sorting(sorting, folder) + + sorting_loaded = NumpyFolderSorting(folder) + check_sortings_equal(sorting_loaded, sorting) + assert np.array_equal(sorting_loaded.unit_ids, sorting.unit_ids) + assert np.array_equal(sorting_loaded.to_spike_vector(), sorting.to_spike_vector(), ) + + + +def test_NpzFolderSorting(): + sorting = generate_sorting() + + folder = cache_folder / "npz_folder_sorting_1" + if folder.is_dir(): + shutil.rmtree(folder) + + NpzFolderSorting.write_sorting(sorting, folder) + + sorting_loaded = NpzFolderSorting(folder) + check_sortings_equal(sorting_loaded, sorting) + assert np.array_equal(sorting_loaded.unit_ids, sorting.unit_ids) + assert np.array_equal(sorting_loaded.to_spike_vector(), sorting.to_spike_vector(), ) + + +if __name__ == "__main__": + test_NumpyFolderSorting() + test_NpzFolderSorting() diff --git a/src/spikeinterface/core/tests/test_template_tools.py b/src/spikeinterface/core/tests/test_template_tools.py index 9057659124..1a79019f96 100644 --- a/src/spikeinterface/core/tests/test_template_tools.py +++ b/src/spikeinterface/core/tests/test_template_tools.py @@ -82,7 +82,7 @@ def test_get_template_extremum_amplitude(): setup_module() test_get_template_amplitudes() - # test_get_template_extremum_channel() - # test_get_template_extremum_channel_peak_shift() - # test_get_template_extremum_amplitude() - # test_get_template_channel_sparsity() + test_get_template_extremum_channel() + test_get_template_extremum_channel_peak_shift() + test_get_template_extremum_amplitude() + test_get_template_channel_sparsity() From e5222372a43695d7f1a89b4fc1876881feed6885 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 7 Jun 2023 12:12:03 +0000 Subject: [PATCH 026/166] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/basesorting.py | 8 +++++--- src/spikeinterface/core/sortingfolder.py | 20 +++++++++---------- .../core/tests/test_basesorting.py | 4 ++-- .../core/tests/test_sorting_folder.py | 11 +++++++--- 4 files changed, 25 insertions(+), 18 deletions(-) diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index 145f43d681..0978a4b405 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -219,15 +219,16 @@ def _save(self, format="numpy_folder", **save_kwargs): """ This function replaces the old CachesortingExtractor, but enables more engines for caching a results. - + Since v0.98.0 'numpy_folder' is used by defult. From v0.96.0 to 0.97.0 'npz_folder' was the default. 
At the moment only 'npz' is supported. """ - if format == "numpy_folder": + if format == "numpy_folder": from .sortingfolder import NumpyFolderSorting + folder = save_kwargs.pop("folder") NumpyFolderSorting.write_sorting(self, folder) cached = NumpyFolderSorting(folder) @@ -235,9 +236,10 @@ def _save(self, format="numpy_folder", **save_kwargs): if self.has_recording(): warnings.warn("The registered recording will not be persistent on disk, but only available in memory") cached.register_recording(self._recording) - + elif format == "npz_folder": from .sortingfolder import NpzFolderSorting + folder = save_kwargs.pop("folder") NpzFolderSorting.write_sorting(self, folder) cached = NpzFolderSorting(folder_path=folder) diff --git a/src/spikeinterface/core/sortingfolder.py b/src/spikeinterface/core/sortingfolder.py index 24de52ec49..55af059510 100644 --- a/src/spikeinterface/core/sortingfolder.py +++ b/src/spikeinterface/core/sortingfolder.py @@ -10,7 +10,6 @@ from .numpyextractors import NumpySortingSegment - class NumpyFolderSorting(BaseSorting): """ NumpyFolderSorting is the new internal format used in spikeinterface (>=0.98.0) @@ -20,6 +19,7 @@ class NumpyFolderSorting(BaseSorting): It is created with the function: `sorting.save(folder='/myfolder', format="numpy_folder")` """ + extractor_name = "NumpyFolderSorting" mode = "folder" name = "NumpyFolder" @@ -29,14 +29,14 @@ def __init__(self, folder_path): with open(folder_path / "numpysorting_info.json", "r") as f: d = json.load(f) - - sampling_frequency = d['sampling_frequency'] - unit_ids = np.array(d['unit_ids']) - num_segments = d['num_segments'] + + sampling_frequency = d["sampling_frequency"] + unit_ids = np.array(d["unit_ids"]) + num_segments = d["num_segments"] BaseSorting.__init__(self, sampling_frequency, unit_ids) - self.spikes = np.load(folder_path / 'spikes.npy', mmap_mode='r') + self.spikes = np.load(folder_path / "spikes.npy", mmap_mode="r") for segment_index in range(num_segments): self.add_sorting_segment(NumpySortingSegment(self.spikes, segment_index, unit_ids)) @@ -59,12 +59,12 @@ def write_sorting(sorting, save_path): if info_file.exists(): raise ValueError("NumpyFolderSorting.write_sorting the folder already contains numpysorting_info.json") d = { - 'sampling_frequency': float(sorting.get_sampling_frequency()), - 'unit_ids': sorting.unit_ids.tolist(), - 'num_segments': sorting.get_num_segments(), + "sampling_frequency": float(sorting.get_sampling_frequency()), + "unit_ids": sorting.unit_ids.tolist(), + "num_segments": sorting.get_num_segments(), } info_file.write_text(json.dumps(d), encoding="utf8") - np.save(save_path / 'spikes.npy', sorting.to_spike_vector()) + np.save(save_path / "spikes.npy", sorting.to_spike_vector()) class NpzFolderSorting(NpzSortingExtractor): diff --git a/src/spikeinterface/core/tests/test_basesorting.py b/src/spikeinterface/core/tests/test_basesorting.py index 141a50d3f1..99857803da 100644 --- a/src/spikeinterface/core/tests/test_basesorting.py +++ b/src/spikeinterface/core/tests/test_basesorting.py @@ -72,14 +72,14 @@ def test_BaseSorting(): # cache old format : npz_folder folder = cache_folder / "simple_sorting_npz_folder" sorting.set_property("test", np.ones(len(sorting.unit_ids))) - sorting.save(folder=folder, format='npz_folder') + sorting.save(folder=folder, format="npz_folder") sorting2 = BaseExtractor.load_from_folder(folder) assert isinstance(sorting2, NpzFolderSorting) # cache new format : numpy_folder folder = cache_folder / "simple_sorting_numpy_folder" sorting.set_property("test", 
np.ones(len(sorting.unit_ids)))
-    sorting.save(folder=folder, format='numpy_folder')
+    sorting.save(folder=folder, format="numpy_folder")
     sorting2 = BaseExtractor.load_from_folder(folder)
     assert isinstance(sorting2, NumpyFolderSorting)

diff --git a/src/spikeinterface/core/tests/test_sorting_folder.py b/src/spikeinterface/core/tests/test_sorting_folder.py
index b1329765ad..cf7cade3ef 100644
--- a/src/spikeinterface/core/tests/test_sorting_folder.py
+++ b/src/spikeinterface/core/tests/test_sorting_folder.py
@@ -27,8 +27,10 @@ def test_NumpyFolderSorting():
     sorting_loaded = NumpyFolderSorting(folder)
     check_sortings_equal(sorting_loaded, sorting)
     assert np.array_equal(sorting_loaded.unit_ids, sorting.unit_ids)
-    assert np.array_equal(sorting_loaded.to_spike_vector(), sorting.to_spike_vector(), )
-
+    assert np.array_equal(
+        sorting_loaded.to_spike_vector(),
+        sorting.to_spike_vector(),
+    )


 def test_NpzFolderSorting():
@@ -43,7 +45,10 @@ def test_NpzFolderSorting():
     sorting_loaded = NpzFolderSorting(folder)
     check_sortings_equal(sorting_loaded, sorting)
     assert np.array_equal(sorting_loaded.unit_ids, sorting.unit_ids)
-    assert np.array_equal(sorting_loaded.to_spike_vector(), sorting.to_spike_vector(), )
+    assert np.array_equal(
+        sorting_loaded.to_spike_vector(),
+        sorting.to_spike_vector(),
+    )


 if __name__ == "__main__":
From b65667b9fa3dad9c34a6781d37806852e2d67f92 Mon Sep 17 00:00:00 2001
From: Samuel Garcia
Date: Wed, 7 Jun 2023 18:01:18 +0200
Subject: [PATCH 027/166] Good trick suggested by Alessio

---
 src/spikeinterface/core/numpyextractors.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py
index e349d5b06b..4ced32100a 100644
--- a/src/spikeinterface/core/numpyextractors.py
+++ b/src/spikeinterface/core/numpyextractors.py
@@ -245,6 +245,9 @@ def from_unit_dict(units_dict_list, sampling_frequency) -> "NumpySorting":

         sorting = NumpySorting(spikes, sampling_frequency, unit_ids)

+        # Trick: populate the cache with the dicts that already exist
+        sorting._cached_spike_trains = {seg_ind:d for seg_ind, d in enumerate(units_dict_list)}
+
         return sorting

     @staticmethod
From af0a28b5fab25f78cb3af0d9bcbb7b1bd313b372 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 7 Jun 2023 16:02:41 +0000
Subject: [PATCH 028/166] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/spikeinterface/core/numpyextractors.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py
index 4ced32100a..5eabdde689 100644
--- a/src/spikeinterface/core/numpyextractors.py
+++ b/src/spikeinterface/core/numpyextractors.py
@@ -245,8 +245,8 @@ def from_unit_dict(units_dict_list, sampling_frequency) -> "NumpySorting":

         sorting = NumpySorting(spikes, sampling_frequency, unit_ids)

-        # Trick: populate the cache with the dicts that already exist
-        sorting._cached_spike_trains = {seg_ind:d for seg_ind, d in enumerate(units_dict_list)}
+        # Trick: populate the cache with the dicts that already exist
+        sorting._cached_spike_trains = {seg_ind: d for seg_ind, d in enumerate(units_dict_list)}

         return sorting

From 20af0090091a142312bbb15a5e3ba1c291e6e115 Mon Sep 17 00:00:00 2001
From: Samuel Garcia
Date: Fri, 9 Jun 2023 16:13:56 +0200
Subject: [PATCH 029/166] get_total_num_spikes() > count_num_spikes_per_unit()
 add count_total_num_spikes()

---
 .../comparison/comparisontools.py | 4 +--
 src/spikeinterface/core/basesorting.py | 31 +++++++++++++++++--
 src/spikeinterface/curation/auto_merge.py | 4 +--
 .../curation/remove_redundant.py | 2 +-
 src/spikeinterface/widgets/unit_depths.py | 2 +-
 5 files changed, 35 insertions(+), 8 deletions(-)

diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py
index c01ea19f14..db45e2b25b 100644
--- a/src/spikeinterface/comparison/comparisontools.py
+++ b/src/spikeinterface/comparison/comparisontools.py
@@ -184,8 +184,8 @@ def make_agreement_scores(sorting1, sorting2, delta_frames, n_jobs=1):
     unit1_ids = np.array(sorting1.get_unit_ids())
     unit2_ids = np.array(sorting2.get_unit_ids())

-    ev_counts1 = np.array(list(sorting1.get_total_num_spikes().values()))
-    ev_counts2 = np.array(list(sorting2.get_total_num_spikes().values()))
+    ev_counts1 = np.array(list(sorting1.count_num_spikes_per_unit().values()))
+    ev_counts2 = np.array(list(sorting2.count_num_spikes_per_unit().values()))
     event_counts1 = pd.Series(ev_counts1, index=unit1_ids)
     event_counts2 = pd.Series(ev_counts2, index=unit2_ids)

diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py
index 0978a4b405..eeefea17eb 100644
--- a/src/spikeinterface/core/basesorting.py
+++ b/src/spikeinterface/core/basesorting.py
@@ -262,8 +262,16 @@ def get_unit_property(self, unit_id, key):
         return v

     def get_total_num_spikes(self):
+        warnings.warn(
+            "Sorting.get_total_num_spikes() is deprecated, use sorting.count_num_spikes_per_unit()",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        return self.count_num_spikes_per_unit()
+
+    def count_num_spikes_per_unit(self):
         """
-        Get total number of spikes for each unit across segments.
+        For each unit: get the number of spikes across segments.

         Returns
         -------
@@ -279,6 +287,17 @@
             num_spikes[unit_id] = n
         return num_spikes

+    def count_total_num_spikes(self):
+        """
+        Get the total number of spikes, summed across segments and units.
+
+        Returns
+        -------
+        total_num_spikes: int
+            The total number of spikes
+        """
+        return self.to_spike_vector().size
+
     def select_units(self, unit_ids, renamed_unit_ids=None):
         """
         Selects a subset of units
@@ -476,14 +495,22 @@ def to_spike_vector(self, concatenated=True, extremum_channel_inds=None, use_cac

         return spikes

-    def to_numpy_sorting(self):
+    def to_numpy_sorting(self, propagate_cache=True):
         """
         Turn any sorting in a NumpySorting.
         usefull to have it in memory with a unique vector representation.
+
+        Parameters
+        ----------
+        propagate_cache : bool
+            Propagate the cache of individual spike trains.
+
         """
         from .numpyextractors import NumpySorting

         sorting = NumpySorting.from_sorting(self)
+        if propagate_cache and self._cached_spike_trains is not None:
+            sorting._cached_spike_trains = self._cached_spike_trains
         return sorting

     def to_shared_memory_sorting(self):
diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py
index 680049dacd..5c11e911a6 100644
--- a/src/spikeinterface/curation/auto_merge.py
+++ b/src/spikeinterface/curation/auto_merge.py
@@ -137,7 +137,7 @@ def get_potential_auto_merge(

     # STEP 1 :
     if "min_spikes" in steps:
-        num_spikes = np.array(list(sorting.get_total_num_spikes().values()))
+        num_spikes = np.array(list(sorting.count_num_spikes_per_unit().values()))
         to_remove = num_spikes < minimum_spikes
         pair_mask[to_remove, :] = False
         pair_mask[:, to_remove] = False
@@ -256,7 +256,7 @@ def compute_correlogram_diff(
     # Index of the middle of the correlograms.
     m = correlograms_smoothed.shape[2] // 2

-    num_spikes = sorting.get_total_num_spikes()
+    num_spikes = sorting.count_num_spikes_per_unit()
     corr_diff = np.full((n, n), np.nan, dtype="float64")
     for unit_ind1 in range(n):
diff --git a/src/spikeinterface/curation/remove_redundant.py b/src/spikeinterface/curation/remove_redundant.py
index 9f2e52ab5e..c2617d5b52 100644
--- a/src/spikeinterface/curation/remove_redundant.py
+++ b/src/spikeinterface/curation/remove_redundant.py
@@ -112,7 +112,7 @@ def remove_redundant_units(
             else:
                 remove_unit_ids.append(u2)
     elif remove_strategy == "max_spikes":
-        num_spikes = sorting.get_total_num_spikes()
+        num_spikes = sorting.count_num_spikes_per_unit()
         for u1, u2 in redundant_unit_pairs:
             if num_spikes[u1] < num_spikes[u2]:
                 remove_unit_ids.append(u1)
diff --git a/src/spikeinterface/widgets/unit_depths.py b/src/spikeinterface/widgets/unit_depths.py
index 5ceee0c133..b645c5e490 100644
--- a/src/spikeinterface/widgets/unit_depths.py
+++ b/src/spikeinterface/widgets/unit_depths.py
@@ -45,7 +45,7 @@ def __init__(
         unit_amplitudes = get_template_extremum_amplitude(we, peak_sign=peak_sign)
         unit_amplitudes = np.abs([unit_amplitudes[unit_id] for unit_id in unit_ids])

-        num_spikes = np.array(list(we.sorting.get_total_num_spikes().values()))
+        num_spikes = np.array(list(we.sorting.count_num_spikes_per_unit().values()))

         plot_data = dict(
             unit_depths=unit_depths,
From a16905cd6815e5b5e036351b426914e5a15ee6ed Mon Sep 17 00:00:00 2001
From: Samuel Garcia
Date: Sun, 11 Jun 2023 19:46:20 +0200
Subject: [PATCH 030/166] Some fix and some docs

---
 doc/modules/core.rst | 23 +++++++++++++++++++
 src/spikeinterface/core/basesorting.py | 16 +++++++-------
 src/spikeinterface/core/numpyextractors.py | 13 ++++++-----
 src/spikeinterface/core/sortingfolder.py | 22 +++++++++++++-----
 .../core/tests/test_basesorting.py | 11 ++++++++-
 5 files changed, 65 insertions(+), 20 deletions(-)

diff --git a/doc/modules/core.rst b/doc/modules/core.rst
index 9af69768dd..af9159a109 100644
--- a/doc/modules/core.rst
+++ b/doc/modules/core.rst
@@ -137,6 +137,7 @@ It interfaces with a spike-sorted output and has the following features:
 * enable selection of sub-units
 * handle time information

+
 Here we assume :code:`sorting` is a :py:class:`~spikeinterface.core.BaseSorting` object
 with 10 units:

@@ -181,6 +182,12 @@ with 10 units:
     # times are not set, the samples are divided by the sampling frequency


+Internally, any sorting object can construct 2 internal caches:
+ 1. a list (per segment) of dict (per unit) of numpy.array. This cache is useful when accessing spike trains unit
+    per unit across segments.
+ 2. a unique numpy.array with a structured dtype, aka the "spikes vector". This is useful for processing in small
+    chunks of time, such as extracting amplitudes from a recording.
+

 WaveformExtractor
 -----------------
@@ -543,6 +550,10 @@ In order to do this, one can use the :code:`Numpy*` classes, :py:class:`~spikein
 but they are not bound to a file. This makes these objects *not dumpable*, so parallel processing is not supported.
 In order to make them *dumpable*, one can simply :code:`save()` them (see :ref:`save_load`).

+Also note the class :py:class:`~spikeinterface.core.SharedMemorySorting`, which is very similar to
+:py:class:`~spikeinterface.core.NumpySorting` but with an underlying SharedMemory buffer, which is useful for
+parallel computing.
+
 In this example, we create a recording and a sorting object from numpy objects:

 .. code-block:: python
@@ -574,6 +585,18 @@ In this example, we create a recording and a sorting object from numpy objects:
                                  sampling_frequency=sampling_frequency)


+Any sorting object can be transformed into a :py:class:`~spikeinterface.core.NumpySorting` or
+:py:class:`~spikeinterface.core.SharedMemorySorting` easily, like this:
+
+.. code-block:: python
+
+    # turn any sorting into a NumpySorting
+    sorting_np = sorting.to_numpy_sorting()
+
+    # or into a SharedMemorySorting for parallel computing
+    sorting_shm = sorting.to_shared_memory_sorting()
+
+
 .. _multi_seg:

 Manipulating objects: slicing, aggregating
diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py
index eeefea17eb..2135562bcd 100644
--- a/src/spikeinterface/core/basesorting.py
+++ b/src/spikeinterface/core/basesorting.py
@@ -526,9 +526,11 @@ def to_shared_memory_sorting(self):
     def to_multiprocessing(self, n_jobs):
         """
         When necessary turn sorting object into:
-            * NumpySorting
-            * SharedMemorySorting
-            * TODO add new format
+            * NumpySorting when n_jobs=1
+            * SharedMemorySorting when n_jobs>1
+
+        If the sorting is already a NumpySorting, SharedMemorySorting or NumpyFolderSorting
+        then this returns the sorting itself, with no transformation.

         Parameters
         ----------
@@ -537,18 +539,18 @@
         Returns
         -------
         sharable_sorting:
-            A sorting that can be
-
+            A sorting that can be used for multiprocessing.
         """
         from .numpyextractors import NumpySorting, SharedMemorySorting
+        from .sortingfolder import NumpyFolderSorting

         if n_jobs == 1:
-            if isinstance(self, (NumpySorting, SharedMemorySorting)):
+            if isinstance(self, (NumpySorting, SharedMemorySorting, NumpyFolderSorting)):
                 return self
             else:
                 return NumpySorting.from_sorting(self)
         else:
-            if isinstance(self, SharedMemorySorting):
+            if isinstance(self, (SharedMemorySorting, NumpyFolderSorting)):
                 return self
             else:
                 return SharedMemorySorting.from_sorting(self)
diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py
index 5eabdde689..841f320aa5 100644
--- a/src/spikeinterface/core/numpyextractors.py
+++ b/src/spikeinterface/core/numpyextractors.py
@@ -153,7 +153,7 @@ def from_sorting(source_sorting: BaseSorting, with_metadata=False) -> "NumpySort
     @staticmethod
     def from_times_labels(times_list, labels_list, sampling_frequency, unit_ids=None) -> "NumpySorting":
         """
-        Construct sorting extractor from:
+        Construct NumpySorting extractor from:
         * an array of spike times (in frames)
         * an array of spike labels and adds all the
         In case of multisegment, it is a list of array.
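To make the constructor documented just above concrete, a minimal single-segment sketch; the spike times and labels are made up:

import numpy as np
from spikeinterface.core import NumpySorting

times = np.array([120, 250, 310, 480])   # spike times in frames
labels = np.array(["a", "b", "a", "b"])  # one unit label per spike
sorting = NumpySorting.from_times_labels(times, labels, sampling_frequency=30000.0)
print(sorting.unit_ids)  # -> ['a' 'b']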
@@ -203,8 +203,8 @@ def from_times_labels(times_list, labels_list, sampling_frequency, unit_ids=None
     @staticmethod
     def from_unit_dict(units_dict_list, sampling_frequency) -> "NumpySorting":
         """
-        Construct sorting extractor from a list of dict.
-        The list length is the segment count
+        Construct NumpySorting from a list of dict.
+        The list length is the segment count.
         Each dict have unit_ids as keys and spike times as values.
 
         Parameters
@@ -253,7 +253,7 @@ def from_unit_dict(units_dict_list, sampling_frequency) -> "NumpySorting":
     @staticmethod
     def from_neo_spiketrain_list(neo_spiketrains, sampling_frequency, unit_ids=None) -> "NumpySorting":
         """
-        Construct a sorting with a neo spiketrain list.
+        Construct a NumpySorting with a neo spiketrain list.
 
         If this is a list of list, it is multi segment.
 
@@ -338,7 +338,6 @@ def get_unit_spike_train(self, unit_id, start_frame, end_frame):
             self.spikes_in_seg = self.spikes[s0:s1]
 
         unit_index = self.unit_ids.index(unit_id)
-
         times = self.spikes_in_seg[self.spikes_in_seg["unit_index"] == unit_index]["sample_index"]
 
         if start_frame is not None:
@@ -376,6 +375,7 @@ def __init__(self, shm_name, shape, sampling_frequency, unit_ids, dtype=minimum_
             shape=shape,
             sampling_frequency=sampling_frequency,
             unit_ids=unit_ids,
+            # this ensures that all dump/load will not be main shm owner
             main_shm_owner=False,
         )
 
@@ -390,7 +390,8 @@ def from_sorting(source_sorting):
         shm_spikes, shm = make_shared_array(spikes.shape, spikes.dtype)
         shm_spikes[:] = spikes
         sorting = SharedMemorySorting(
-            shm.name, spikes.shape, source_sorting.get_sampling_frequency(), source_sorting.unit_ids, dtype=spikes.dtype
+            shm.name, spikes.shape, source_sorting.get_sampling_frequency(), source_sorting.unit_ids,
+            dtype=spikes.dtype, main_shm_owner=True
         )
         shm.close()
         return sorting
diff --git a/src/spikeinterface/core/sortingfolder.py b/src/spikeinterface/core/sortingfolder.py
index 55af059510..32b983f7e7 100644
--- a/src/spikeinterface/core/sortingfolder.py
+++ b/src/spikeinterface/core/sortingfolder.py
@@ -12,9 +12,13 @@
 
 class NumpyFolderSorting(BaseSorting):
     """
-    NumpyFolderSorting is the new internal format used in spikeinterface (>=0.98.0)
+    NumpyFolderSorting is the new internal format used in spikeinterface (>=0.98.0) for caching
+    sorting objects.
 
-    It is a simple folder that contains all flatten spikes (using sorting.to_spike_vector() in a npy format.
+    It is a simple folder that contains:
+      * a file "spike.npy" (numpy format) with all flattened spikes (using sorting.to_spike_vector())
+      * a "numpysorting_info.json" containing sampling_frequency, unit_ids and num_segments
+      * a metadata folder for unit properties.
 
     It is created with the function: `sorting.save(folder='/myfolder', format="numpy_folder")`
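 
+    For instance, a minimal sketch (the folder path is hypothetical):
+
+    >>> sorting.save(folder="/myfolder", format="numpy_folder")
+    >>> sorting_cached = NumpyFolderSorting("/myfolder")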
 
     Parameters
     ----------
@@ -47,7 +51,7 @@ def __init__(self, folder_path):
             folder_metadata = folder_path
             self.load_metadata_from_folder(folder_metadata)
 
-        self._kwargs = dict(folder_path=folder_path.absolute())
+        self._kwargs = dict(folder_path=str(folder_path.absolute()))
 
     @staticmethod
     def write_sorting(sorting, save_path):
@@ -70,9 +74,14 @@ def write_sorting(sorting, save_path):
 
 class NpzFolderSorting(NpzSortingExtractor):
     """
     NpzFolderSorting is the old internal format used in spikeinterface (<=0.97.0)
-    It is a NpzSortingExtractor + metadata contained in a folder.
 
-    It is created with the function: `sorting.save(folder='/myfolder', format="npz")`
+    This is a folder that contains:
+
+    * "sorting_cached.npz" file in the NpzSortingExtractor format
+    * "npz.json" which is the json description of NpzSortingExtractor
+    * a metadata folder for unit properties.
+
+    It is created with the function: `sorting.save(folder='/myfolder', format="npz_folder")`
 
     Parameters
     ----------
@@ -111,11 +120,12 @@ def __init__(self, folder_path):
 
     @staticmethod
     def write_sorting(sorting, save_path):
-        # the folder can already exists but not contaning numpysorting_info.json
         save_path = Path(save_path)
         save_path.mkdir(parents=True, exist_ok=True)
 
         npz_file = save_path / "sorting_cached.npz"
+        if npz_file.exists():
+            raise ValueError("NpzFolderSorting.write_sorting the folder already contains sorting_cached.npz")
         NpzSortingExtractor.write_sorting(sorting, npz_file)
         cached = NpzSortingExtractor(npz_file)
         cached.dump(save_path / "npz.json", relative_to=save_path)
diff --git a/src/spikeinterface/core/tests/test_basesorting.py b/src/spikeinterface/core/tests/test_basesorting.py
index 99857803da..15653d4339 100644
--- a/src/spikeinterface/core/tests/test_basesorting.py
+++ b/src/spikeinterface/core/tests/test_basesorting.py
@@ -93,14 +93,19 @@ def test_BaseSorting():
     check_sortings_equal(sorting, sorting4, check_annotations=True, check_properties=True)
 
     with pytest.warns(DeprecationWarning):
-        spikes = sorting.get_all_spike_trains()
+        num_spikes = sorting.get_all_spike_trains()
         # print(spikes)
 
     spikes = sorting.to_spike_vector()
     # print(spikes)
+    assert sorting._cached_spike_vector is not None
     spikes = sorting.to_spike_vector(extremum_channel_inds={0: 15, 1: 5, 2: 18})
     # print(spikes)
 
+    num_spikes_per_unit = sorting.count_num_spikes_per_unit()
+    total_spikes = sorting.count_total_num_spikes()
+
+
     # select units
     keep_units = [0, 1]
     sorting_select = sorting.select_units(unit_ids=keep_units)
@@ -113,6 +118,10 @@ def test_BaseSorting():
     sorting_clean = sorting_empty.remove_empty_units()
     for unit in sorting_clean.get_unit_ids():
         assert unit not in empty_units
+
+    sorting4 = sorting.to_numpy_sorting()
+    sorting5 = sorting.to_multiprocessing(n_jobs=2)
+    del sorting5
 
 
 def test_npy_sorting():
From 7d34c0c93624fd1457b8e2981f02236f12e25a4c Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Sun, 11 Jun 2023 17:46:56 +0000
Subject: [PATCH 031/166] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 doc/modules/core.rst                              | 2 +-
 src/spikeinterface/core/numpyextractors.py        | 8 ++++++--
 src/spikeinterface/core/tests/test_basesorting.py | 3 +--
 3 files changed, 8 insertions(+), 5 deletions(-)

diff --git a/doc/modules/core.rst b/doc/modules/core.rst
index af9159a109..df939f1489 100644
--- a/doc/modules/core.rst
+++ b/doc/modules/core.rst
@@ -585,7 +585,7 @@ In this example, we create a recording and a sorting object from numpy objects:
                            sampling_frequency=sampling_frequency)
 
 
-Any sorting object can be transformed into a :py:class:`~spikeinterface.core.NumpySorting` or 
+Any sorting object can be transformed into a :py:class:`~spikeinterface.core.NumpySorting` or
 :py:class:`~spikeinterface.core.SharedMemorySorting` easily like this
 
 .. 
code-block:: python diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py index 841f320aa5..0c3026d961 100644 --- a/src/spikeinterface/core/numpyextractors.py +++ b/src/spikeinterface/core/numpyextractors.py @@ -390,8 +390,12 @@ def from_sorting(source_sorting): shm_spikes, shm = make_shared_array(spikes.shape, spikes.dtype) shm_spikes[:] = spikes sorting = SharedMemorySorting( - shm.name, spikes.shape, source_sorting.get_sampling_frequency(), source_sorting.unit_ids, - dtype=spikes.dtype, main_shm_owner=True + shm.name, + spikes.shape, + source_sorting.get_sampling_frequency(), + source_sorting.unit_ids, + dtype=spikes.dtype, + main_shm_owner=True, ) shm.close() return sorting diff --git a/src/spikeinterface/core/tests/test_basesorting.py b/src/spikeinterface/core/tests/test_basesorting.py index 15653d4339..563f6cd4e8 100644 --- a/src/spikeinterface/core/tests/test_basesorting.py +++ b/src/spikeinterface/core/tests/test_basesorting.py @@ -105,7 +105,6 @@ def test_BaseSorting(): num_spikes_per_unit = sorting.count_num_spikes_per_unit() total_spikes = sorting.count_total_num_spikes() - # select units keep_units = [0, 1] sorting_select = sorting.select_units(unit_ids=keep_units) @@ -118,7 +117,7 @@ def test_BaseSorting(): sorting_clean = sorting_empty.remove_empty_units() for unit in sorting_clean.get_unit_ids(): assert unit not in empty_units - + sorting4 = sorting.to_numpy_sorting() sorting5 = sorting.to_multiprocessing(n_jobs=2) del sorting5 From 4dae8b605bfa7fe854c9e81368ea806019bca417 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 6 Jul 2023 19:39:26 +0000 Subject: [PATCH 032/166] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/numpyextractors.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py index fe3f04fc74..400362573f 100644 --- a/src/spikeinterface/core/numpyextractors.py +++ b/src/spikeinterface/core/numpyextractors.py @@ -126,7 +126,6 @@ def __init__(self, spikes, sampling_frequency, unit_ids): self._is_dumpable = False self._is_json_serializable = False - if spikes.size == 0: nseg = 1 else: From 2e6f845eacf29d9f5d06966921718cc3fbc59efa Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 6 Jul 2023 22:04:30 +0200 Subject: [PATCH 033/166] _is_json_serializable --- src/spikeinterface/core/numpyextractors.py | 5 +++-- src/spikeinterface/core/tests/test_basesorting.py | 1 + 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py index fe3f04fc74..6bb85ae44c 100644 --- a/src/spikeinterface/core/numpyextractors.py +++ b/src/spikeinterface/core/numpyextractors.py @@ -123,7 +123,7 @@ def __init__(self, spikes, sampling_frequency, unit_ids): """ """ BaseSorting.__init__(self, sampling_frequency, unit_ids) - self._is_dumpable = False + self._is_dumpable = True self._is_json_serializable = False @@ -356,7 +356,8 @@ def __init__(self, shm_name, shape, sampling_frequency, unit_ids, dtype=minimum_ assert shape[0] > 0, "SharedMemorySorting only supported with no empty sorting" BaseSorting.__init__(self, sampling_frequency, unit_ids) - self.is_dumpable = True + self._is_dumpable = True + self._is_json_serializable = False self.shm = SharedMemory(shm_name, create=False) self.shm_spikes = np.ndarray(shape=shape, 
dtype=dtype, buffer=self.shm.buf) diff --git a/src/spikeinterface/core/tests/test_basesorting.py b/src/spikeinterface/core/tests/test_basesorting.py index 563f6cd4e8..fdb1b22d7d 100644 --- a/src/spikeinterface/core/tests/test_basesorting.py +++ b/src/spikeinterface/core/tests/test_basesorting.py @@ -31,6 +31,7 @@ def test_BaseSorting(): num_seg = 2 file_path = cache_folder / "test_BaseSorting.npz" + file_path.parent.mkdir(exist_ok=True) create_sorting_npz(num_seg, file_path) From 76d264b72e0d851f3a0b50cea05ef1e53692b818 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 10 Jul 2023 13:24:05 +0200 Subject: [PATCH 034/166] Add depth_order kwargs --- src/spikeinterface/preprocessing/depth_order.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/spikeinterface/preprocessing/depth_order.py b/src/spikeinterface/preprocessing/depth_order.py index e30176f099..944b8d1f75 100644 --- a/src/spikeinterface/preprocessing/depth_order.py +++ b/src/spikeinterface/preprocessing/depth_order.py @@ -31,6 +31,11 @@ def __init__(self, parent_recording, channel_ids=None, dimensions=("x", "y")): parent_recording, channel_ids=reordered_channel_ids, ) + self._kwargs = dict( + parent_recording=parent_recording, + channel_ids=channel_ids, + dimensions=dimensions, + ) depth_order = define_function_from_class(source_class=DepthOrderRecording, name="depth_order") From 9d01efaacdea957c7e862771842dc96cb0f2f6fc Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 10 Jul 2023 15:15:16 +0200 Subject: [PATCH 035/166] Fix has_channel_locations function --- src/spikeinterface/core/baserecordingsnippets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index 259d3edc17..affde8a75e 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -58,7 +58,7 @@ def has_probe(self): return "contact_vector" in self.get_property_keys() def has_channel_location(self): - return self.has_probe() or "channel_location" in self.get_property_keys() + return self.has_probe() or "location" in self.get_property_keys() def is_filtered(self): # the is_filtered is handle with annotation From d4a59479b92963d32c52d3ba2d08afb7f75bf411 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Mon, 10 Jul 2023 13:34:38 -0400 Subject: [PATCH 036/166] typos--will need drift clarification --- doc/how_to/handle_drift.rst | 12 +++++------ doc/modules/core.rst | 4 ++-- doc/modules/motion_correction.rst | 34 +++++++++++++++---------------- 3 files changed, 25 insertions(+), 25 deletions(-) diff --git a/doc/how_to/handle_drift.rst b/doc/how_to/handle_drift.rst index b59fa4dfcb..6bdc366cb7 100644 --- a/doc/how_to/handle_drift.rst +++ b/doc/how_to/handle_drift.rst @@ -7,7 +7,7 @@ Handle motion/drift with spikeinterface ======================================= -Spikeinterface offers a very flexible framework to handle drift as a +SpikeInterface offers a very flexible framework to handle drift as a preprocessing step. If you want to know more, please read the :ref:`motion_correction` section of the documentation. @@ -96,7 +96,7 @@ Correcting for drift is easy! You just need to run a single function. We will try this function with 3 presets. Internally a preset is a dictionary of dictionaries containing all -parameters for every steps. +parameters for each step. 
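+
+A minimal sketch of the call (assuming a preprocessed recording ``rec``; the preset name
+is one of the three discussed below):
+
+.. code:: ipython3
+
+    # all step parameters are taken from the preset dictionaries
+    rec_corrected = si.correct_motion(rec, preset="nonrigid_accurate")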
Here we also save the motion correction results into a folder to be able to load
 them later.
 
@@ -185,14 +185,14 @@ A few comments on the figures:
    start moving is recovered quite well.
  * The preset **kilosort_like** gives better results because it is a non-rigid case.
    The motion vector is computed for different depths. The corrected peak locations are
-   flatter than the rigid case. The motion vector map is still be a bit
-   noisy at some depths (e.g around 1000um).
+   flatter than the rigid case. The motion vector map is still a bit
+   noisy at some depths (e.g. around 1000um).
  * The preset **nonrigid_accurate** seems to give the best results on this recording.
    The motion vector seems less noisy globally, but it is not "perfect" (see at the top of the
    probe 3200um to 3800um). Also note that in the first part of the recording before the imposed
    motion (0-600s) we clearly have a non-rigid motion: the upper part of the probe
-   (2000-3000um) experience some drifts, but the lower part (0-1000um) is
+   (2000-3000um) experiences some drift, but the lower part (0-1000um) is
    relatively stable. The method defined by this preset is able to capture this.
 
 .. code:: ipython3
@@ -237,7 +237,7 @@ axis, especially for the preset "nonrigid_accurate".
 
 Be aware that there are two ways to correct for the motion:
 
 1. Interpolate traces and detect/localize peaks again
-(:py:func:`interpolate_recording()`) 2. Compensate for drifts directly on peak
+(:py:func:`interpolate_recording()`) 2. Compensate for drift directly on peak
 locations (:py:func:`correct_motion_on_peaks()`)
 
 Case 1 is used before running a spike sorter and the case 2 is used here
diff --git a/doc/modules/core.rst b/doc/modules/core.rst
index f241d90df5..9923eacd39 100644
--- a/doc/modules/core.rst
+++ b/doc/modules/core.rst
@@ -494,7 +494,7 @@ In general, parallelization is achieved by splitting the recording in many small
 them in parallel (for more details, see the :py:class:`~spikeinterface.core.ChunkRecordingExecutor` class).
 
 Many functions support parallel processing (e.g., :py:func:`~spikeinterface.core.extract_waveforms`, :code:`save`,
-and many more). All of this functions, in addition to other arguments, also accept the so-called **job_kwargs**.
+and many more). All of these functions, in addition to other arguments, also accept the so-called **job_kwargs**.
 These are a set of keyword arguments which are common to all functions that support parallelization:
 
   * chunk_duration or chunk_size or chunk_memory or total_memory
@@ -739,7 +739,7 @@ There are also some more advanced functions to generate sorting objects with var
 Downloading test datasets
 -------------------------
 
-The `NEO `_ package is maintaining a collection a files of many
+The `NEO `_ package is maintaining a collection of many
 electrophysiology file formats: https://gin.g-node.org/NeuralEnsemble/ephy_testing_data
 
 The :py:func:`~spikeinterface.core.download_dataset` function is capable of downloading and caching locally dataset
diff --git a/doc/modules/motion_correction.rst b/doc/modules/motion_correction.rst
index 1b582dbafc..6dc949625d 100644
--- a/doc/modules/motion_correction.rst
+++ b/doc/modules/motion_correction.rst
@@ -7,26 +7,26 @@ Motion/drift correction
 Overview
 --------
 
-Mechanical drifts, often observed in recordings, are currently a major issue for spike sorting. This is especially striking
-with the new generation of high-density devices used in-vivo such as the neuropixel electrodes.
-The first sorter that has introduced motion/drift correction as a prepossessing step was kilosort2.5 (see [Steinmetz2021]_)
+Mechanical drift, often observed in recordings, is currently a major issue for spike sorting. This is especially striking
+with the new generation of high-density devices used for in-vivo electrophysiology such as the neuropixel electrodes.
+The first sorter that has introduced motion/drift correction as a preprocessing step was Kilosort2.5 (see [Steinmetz2021]_)
 
 Long story short, the main idea is the same as the one used for non-rigid image registration, for example with calcium
 imaging. However, because with extracellular recording we do not have a proper image to use as a reference, the
 main idea of the algorithm is create an "image" via the activity profile of the cells during a given time window.
 Assuming this activity profile should be kept constant over time, the motion can be estimated, by blocks, along
 the probe's insertion axis
-(i.e. depth) so that we can interpolate the traces to compensate this estimated motion.
+(i.e. depth) so that we can interpolate the traces to compensate for this estimated motion.
 
 Users with a need to handle drift were currently forced to stick to the use of Kilosort2.5 or pyKilosort (see [Pachitariu2023]_).
 Recently, the Paninski group from Columbia University introduced a possibly more accurate method to estimate the drift (see [Varol2021]_
-and [Windolf2023]_), but this new method was not properly integrated in any sorter.
+and [Windolf2023]_), but this new method was not properly integrated into any sorter.
 
-Because motion registration is a hard topic, with numerous hypothesis and/or implementations details that might have a large
+Because motion registration is a hard topic, with numerous hypotheses and/or implementation details that might have a large
 impact on the spike sorting performances (see [Garcia2023]_), in SpikeInterface, we developed a full motion estimation
 and interpolation framework to make all these methods accessible in one place. This modular approach offers a major benefit:
 **the drift correction can be applied to a recording as a preprocessing step, and then used for any sorter!**
 In short, the motion correction is decoupled from the sorter itself.
-This gives the user an incredible flexibility to check/test and correct the drifts before the sorting process.
+This gives the user incredible flexibility to check/test and correct the drift before the sorting process.
 
 Here is an overview of the motion correction as a preprocessing:
 
@@ -41,21 +41,21 @@ The motion correction process can be split into 3 steps:
 
 For every steps, we implemented several methods. The combination of the yellow boxes should give more or less what
 Kilosort2.5/3 is doing. Similarly, the combination of the green boxes gives the method developed by the Paninski group.
-Of course the end user can combine any of the methods to get the best motion correction possible.
-This make also an incredible framework for testing new ideas.
+Of course the end user can combine any of these methods to get the best motion correction possible.
+This also makes an incredible framework for testing new ideas.
 
 For a better overview, checkout our recent paper to validate, benchmark, and compare these motion correction
 methods (see [Garcia2023]_).
 
 SpikeInterface offers two levels for motion correction:
 
   1. A high level with a unique function and predefined parameter presets
-  2. A low level where the user need to call one by one all functions for a better control
+  2. A low level where the user needs to call one by one all functions for better control
 
 
 High-level API
 --------------
 
-One challenging task for motion correction is to find parameters.
+One challenging task for motion correction is to determine the parameters.
 The high level :py:func:`~spikeinterface.preprocessing.correct_motion()` proposes the concept of a **"preset"** that
 already has predefined parameters, in order to achieve a calibrated behavior.
 
@@ -69,7 +69,7 @@ We currently have 3 presets:
   To be used as check and/or control on a recording to check the presence of drift.
   Note that, in this case the drift is considered as "rigid" over the electrode.
 * **"kilosort_like"**: It consists of *grid convolution + iterative_template + kriging*, to mimic what is done in Kilosort (see [Pachitariu2023]_).
-  Note that this is not exactly 100% what Kilosort is doing, because the peak detection is done with a template mathcing
+  Note that this is not exactly 100% what Kilosort is doing, because the peak detection is done with a template matching
   in Kilosort, while in SpikeInterface we used a threshold-based method. However, this "preset" gives similar results
  to Kilosort2.5.
 
@@ -85,7 +85,7 @@ We currently have 3 presets:
 
     rec_corrected = correct_motion(rec, preset="nonrigid_accurate")
 
 The process is quite long due the two first steps (activity profile + motion inference)
-But the return :code:`rec_corrected` is a lazy recording object this will interpolate traces on the
+But the returned :code:`rec_corrected` is a lazy recording object that will interpolate traces on the
 fly (step 3 motion interpolation).
 
 
@@ -116,7 +116,7 @@ Optionally any parameter from the preset can be overwritten:
         )
     )
 
-Importantly, all the result and intermediate computation can be saved into a folder for further loading
+Importantly, all the results and intermediate computations can be saved into a folder for further loading
 and checking. The folder will contain the motion vector itself of course but also detected peaks, peak location, and more.
 
@@ -134,11 +134,11 @@ Low-level API
 -------------
 
 All steps (**activity profile**, **motion inference**, **motion interpolation**) can be launched with distinct functions.
-This can be useful to find the good method and finely tune/optimize parameters at every steps.
+This can be useful to find the best method and finely tune/optimize parameters at each step.
 All functions are implemented in the :py:mod:`~spikeinterface.sortingcomponents` module. They all have a simple API
 with SpikeInterface objects or numpy arrays as inputs.
 Since motion correction is a hot topic, these functions have many possible methods and also many possible parameters.
-Finding the good combination of method/parameters is not that easy, but it should be doable, assuming the presets are not
+Finding the best combination of method/parameters is not that easy, but it should be doable, assuming the presets are not
 working properly for your particular case.
 
 
@@ -186,7 +186,7 @@ The function :py:func:`~spikeinterface.preprocessing.correct_motion()` requires
 It is important to keep in mind that the preprocessing can have a strong impact on the motion estimation.
 
-In the context of motion correction we advice:
+In the context of motion correction we advise (see the sketch after this list):
 
   * to not use whitening before motion estimation (as it interferes with spatial amplitude information)
  * to remove high frequencies in traces, to reduce noise in peak location (e.g. using a bandpass filter)
  * if you use Neuropixels, then use :py:func:`~spikeinterface.preprocessing.phase_shift()` in preprocessing
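+
+A minimal sketch of such a chain (the cut-off frequencies are illustrative, not prescribed values):
+
+.. code-block:: python
+
+    from spikeinterface.preprocessing import phase_shift, bandpass_filter, correct_motion
+
+    # phase-shift correction for Neuropixels, then band-pass to reduce noise in peak localization
+    rec = phase_shift(rec)
+    rec = bandpass_filter(rec, freq_min=300., freq_max=6000.)
+
+    # no whitening before motion estimation
+    rec_corrected = correct_motion(rec, preset="nonrigid_accurate")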

From ee13624f171de0380fa00a50c62bc0df422b8760 Mon Sep 17 00:00:00 2001
From: Zach McKenzie <92116279+zm711@users.noreply.github.com>
Date: Mon, 10 Jul 2023 14:08:35 -0400
Subject: [PATCH 037/166] hard code template writing to be float64

---
 src/spikeinterface/exporters/to_phy.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/spikeinterface/exporters/to_phy.py b/src/spikeinterface/exporters/to_phy.py
index 33a46d2bea..58804037b0 100644
--- a/src/spikeinterface/exporters/to_phy.py
+++ b/src/spikeinterface/exporters/to_phy.py
@@ -168,7 +168,7 @@ def export_to_phy(
     # shape (num_units, num_samples, max_num_channels)
     max_num_channels = max(len(chan_inds) for chan_inds in sparse_dict.values())
     num_samples = waveform_extractor.nbefore + waveform_extractor.nafter
-    templates = np.zeros((len(unit_ids), num_samples, max_num_channels), dtype=waveform_extractor.dtype)
+    templates = np.zeros((len(unit_ids), num_samples, max_num_channels), dtype="float64")
     # here we pad template inds with -1 if len of sparse channels is unequal
     templates_ind = -np.ones((len(unit_ids), max_num_channels), dtype="int64")
     for unit_ind, unit_id in enumerate(unit_ids):
From ab3f5709e7c66c3fb3a08a051dc941979371b296 Mon Sep 17 00:00:00 2001
From: Heberto Mayorquin
Date: Tue, 11 Jul 2023 10:33:22 +0200
Subject: [PATCH 038/166] fix typo in class attribute

---
 src/spikeinterface/extractors/neoextractors/neuralynx.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/spikeinterface/extractors/neoextractors/neuralynx.py b/src/spikeinterface/extractors/neoextractors/neuralynx.py
index 6f73952eb1..58e97a69ef 100644
--- a/src/spikeinterface/extractors/neoextractors/neuralynx.py
+++ b/src/spikeinterface/extractors/neoextractors/neuralynx.py
@@ -61,7 +61,7 @@ class NeuralynxSortingExtractor(NeoBaseSortingExtractor):
 
     mode = "folder"
     NeoRawIOClass = "NeuralynxRawIO"
-    neo_returns_timestamps = False
+    neo_returns_frames = True
     need_t_start_from_signal_stream = True
     name = "neuralynx"
From f779e12191515b49ff4f73a02d453f44e1be00ec Mon Sep 17 00:00:00 2001
From: Heberto Mayorquin
Date: Tue, 11 Jul 2023 10:42:08 +0200
Subject: [PATCH 039/166] Add docs requirements and build read-the-docs
 documentation faster (#1807)

* build docs faster in branches
* hope read the docs saves the copy locally and hope for the best
* it WORKS! 
* adding numa * pin dependencies * feedback * another comment * Add comment reminder for release --------- Co-authored-by: Alessio Buccino --- doc/install_sorters.rst | 2 +- docs_rtd.yml | 9 +++++++++ pyproject.toml | 17 +++++++++++++++++ readthedocs.yml | 17 +++++------------ 4 files changed, 32 insertions(+), 13 deletions(-) create mode 100644 docs_rtd.yml diff --git a/doc/install_sorters.rst b/doc/install_sorters.rst index 1e55827ffd..3fda05848c 100644 --- a/doc/install_sorters.rst +++ b/doc/install_sorters.rst @@ -191,7 +191,7 @@ Mountainsort5 pip install mountainsort5 SpyKING CIRCUS -^^^^^^^^^^^^^ +^^^^^^^^^^^^^^ * Python, requires MPICH * Url: https://spyking-circus.readthedocs.io diff --git a/docs_rtd.yml b/docs_rtd.yml new file mode 100644 index 0000000000..c4e1fb378c --- /dev/null +++ b/docs_rtd.yml @@ -0,0 +1,9 @@ +channels: + - conda-forge + - defaults +dependencies: + - python=3.10 + - pip + - datalad + - pip: + - -e .[docs] diff --git a/pyproject.toml b/pyproject.toml index 574cd79830..1b6d116e4b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -143,6 +143,23 @@ test = [ "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", ] +docs = [ + "Sphinx==5.1.1", + "sphinx_rtd_theme==1.0.0", + "sphinx-gallery", + "numpydoc", + + # for notebooks in the gallery + "MEArec", # Use as an example + "datalad==0.16.2", # Download mearec data, not sure if needed as is installed with conda as well because of git-annex + "pandas", # Don't know where this is needed + "hdbscan", # For sorters, probably spikingcircus + "numba", # For sorters, probably spikingcircus + # for release we need pypi, so this needs to be commented + "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", # We always build from the latest version + "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", # We always build from the latest version + +] [tool.pytest.ini_options] markers = [ diff --git a/readthedocs.yml b/readthedocs.yml index 350948104d..512fcbc709 100644 --- a/readthedocs.yml +++ b/readthedocs.yml @@ -1,17 +1,10 @@ version: 2 build: - image: latest + os: ubuntu-22.04 + tools: + python: "mambaforge-4.10" -conda: - environment: environment_rtd.yml - -# python: -# install: -# - method: pip -# path: . 
-# python: -# version: 3.8 -# install: -# - requirements: requirements_rtd.txt +conda: + environment: docs_rtd.yml From 5bafe5aad7c66b2576abf3a251a09ffe2df72950 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 12 Jul 2023 13:40:00 +0200 Subject: [PATCH 040/166] make recording optional, add amplitude_clim and alpha --- doc/how_to/handle_drift.rst | 4 +-- examples/how_to/handle_drift.py | 2 +- .../widgets/matplotlib/motion.py | 32 +++++++++++++---- src/spikeinterface/widgets/motion.py | 36 ++++++++++++++----- 4 files changed, 55 insertions(+), 19 deletions(-) diff --git a/doc/how_to/handle_drift.rst b/doc/how_to/handle_drift.rst index 6bdc366cb7..a4fd13097d 100644 --- a/doc/how_to/handle_drift.rst +++ b/doc/how_to/handle_drift.rst @@ -204,8 +204,8 @@ A few comments on the figures: # and plot fig = plt.figure(figsize=(14, 8)) - si.plot_motion(rec, motion_info, figure=fig, depth_lim=(400, 600), - color_amplitude=True, amplitude_cmap='inferno', scatter_decimate=10) + si.plot_motion(motion_info, recording=rec, figure=fig, depth_lim=(400, 600), + color_amplitude=True, amplitude_cmap='inferno', scatter_decimate=10) fig.suptitle(f"{preset=}") diff --git a/examples/how_to/handle_drift.py b/examples/how_to/handle_drift.py index 9c2b09954e..7f4a4e2db4 100644 --- a/examples/how_to/handle_drift.py +++ b/examples/how_to/handle_drift.py @@ -119,7 +119,7 @@ def preprocess_chain(rec): # and plot fig = plt.figure(figsize=(14, 8)) - si.plot_motion(rec, motion_info, figure=fig, depth_lim=(400, 600), + si.plot_motion(motion_info, recording=rec, figure=fig, depth_lim=(400, 600), color_amplitude=True, amplitude_cmap='inferno', scatter_decimate=10) fig.suptitle(f"{preset=}") diff --git a/src/spikeinterface/widgets/matplotlib/motion.py b/src/spikeinterface/widgets/matplotlib/motion.py index abf02f4697..6ed4f2f685 100644 --- a/src/spikeinterface/widgets/matplotlib/motion.py +++ b/src/spikeinterface/widgets/matplotlib/motion.py @@ -3,6 +3,7 @@ from .base_mpl import MplPlotter import numpy as np +from matplotlib.colors import Normalize class MotionPlotter(MplPlotter): @@ -36,11 +37,16 @@ def do_plot(self, data_plot, **backend_kwargs): else: motion_lim = dp.motion_lim + if dp.times is None: + times = np.arange(np.max(dp.peaks["sample_index"]) + 1) / dp.sampling_frequency + else: + times = dp.times + corrected_location = correct_motion_on_peaks( - dp.peaks, dp.peak_locations, dp.rec.get_times(), dp.motion, dp.temporal_bins, dp.spatial_bins, direction="y" + dp.peaks, dp.peak_locations, times, dp.motion, dp.temporal_bins, dp.spatial_bins, direction="y" ) - x = dp.peaks["sample_index"] / dp.rec.get_sampling_frequency() + x = dp.peaks["sample_index"] / dp.sampling_frequency y = dp.peak_locations["y"] y2 = corrected_location["y"] if dp.scatter_decimate is not None: @@ -49,17 +55,26 @@ def do_plot(self, data_plot, **backend_kwargs): y2 = y2[:: dp.scatter_decimate] if dp.color_amplitude: - amps = np.abs(dp.peaks["amplitude"]) - amps /= np.quantile(amps, 0.95) + amps = dp.peaks["amplitude"] + amps_abs = np.abs(amps) + q_95 = np.quantile(amps_abs, 0.95) if dp.scatter_decimate is not None: amps = amps[:: dp.scatter_decimate] - c = plt.get_cmap(dp.amplitude_cmap)(amps) + cmap = plt.get_cmap(dp.amplitude_cmap) + if dp.amplitude_clim is None: + amps = amps_abs + amps /= q_95 + c = cmap(amps) + else: + norm_function = Normalize(vmin=dp.amplitude_clim[0], vmax=dp.amplitude_clim[1], clip=True) + c = cmap(norm_function(amps)) color_kwargs = dict( color=None, c=c, - ) # alpha=0.02 + alpha=dp.amplitude_alpha, + ) else: - 
color_kwargs = dict(color="k", c=None) # alpha=0.02 + color_kwargs = dict(color="k", c=None, alpha=dp.amplitude_alpha) ax0.scatter(x, y, s=1, **color_kwargs) # for i in range(dp.motion.shape[1]): @@ -80,6 +95,7 @@ def do_plot(self, data_plot, **backend_kwargs): ax2.set_ylim(-motion_lim, motion_lim) ax2.set_ylabel("motion [um]") ax2.set_title("Motion vectors") + axes = [ax0, ax1, ax2] if not is_rigid: im = ax3.imshow( @@ -99,6 +115,8 @@ def do_plot(self, data_plot, **backend_kwargs): ax3.set_xlabel("Times [s]") ax3.set_ylabel("Depth [um]") ax3.set_title("Motion vectors") + axes.append(ax3) + self.axes = np.array(axes) MotionPlotter.register(MotionWidget) diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py index 6a89050856..6feb597d46 100644 --- a/src/spikeinterface/widgets/motion.py +++ b/src/spikeinterface/widgets/motion.py @@ -14,43 +14,61 @@ class MotionWidget(BaseWidget): Parameters ---------- - recording : RecordingExtractor - The recording extractor object motion_info: dict The motion info return by correct_motion() or load back with load_motion_info() - depth_lim: tuple + recording : RecordingExtractor, optional + The recording extractor object (used to get sampling frequency and times), default None) + sampling_frequency : float, optional + The sampling frequency (needed if recording is None), default None + depth_lim : tuple The min and max depth to display, default None (min and max of the recording) - motion_lim: tuple + motion_lim : tuple The min and max motion to display, default None (min and max of the motion) - color_amplitude: bool + color_amplitude : bool If True, the color of the scatter points is the amplitude of the peaks, default False - scatter_decimate: int + scatter_decimate : int If > 1, the scatter points are decimated, default None - amplitude_cmap: str + amplitude_cmap : str The colormap to use for the amplitude, default 'inferno' + amplitude_clim : tuple + The min and max amplitude to display, default None (min and max of the amplitudes) + amplitude_alpha : float + The alpha of the scatter points, default 0.5 """ possible_backends = {} def __init__( self, - recording, motion_info, + recording=None, + sampling_frequency=None, depth_lim=None, motion_lim=None, color_amplitude=False, scatter_decimate=None, amplitude_cmap="inferno", + amplitude_clim=None, + amplitude_alpha=1, backend=None, **backend_kwargs, ): + assert recording or sampling_frequency, "recording or sampling_frequency must be provided" + if recording is not None: + sampling_frequency = recording.sampling_frequency + + times = recording.get_times() if recording is not None else None + plot_data = dict( - rec=recording, + sampling_frequency=sampling_frequency, + times=times, depth_lim=depth_lim, motion_lim=motion_lim, color_amplitude=color_amplitude, scatter_decimate=scatter_decimate, amplitude_cmap=amplitude_cmap, + amplitude_clim=amplitude_clim, + amplitude_alpha=amplitude_alpha, **motion_info, ) From 3b96e9ab4ea9e2e78853ea14fa89e4a9eb3b43cb Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 12 Jul 2023 14:52:34 +0200 Subject: [PATCH 041/166] oups --- src/spikeinterface/widgets/matplotlib/motion.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/spikeinterface/widgets/matplotlib/motion.py b/src/spikeinterface/widgets/matplotlib/motion.py index 6ed4f2f685..46f0f7d3e3 100644 --- a/src/spikeinterface/widgets/matplotlib/motion.py +++ b/src/spikeinterface/widgets/matplotlib/motion.py @@ -60,6 +60,7 @@ def do_plot(self, data_plot, **backend_kwargs): q_95 = 
np.quantile(amps_abs, 0.95) if dp.scatter_decimate is not None: amps = amps[:: dp.scatter_decimate] + amps_abs = amps_abs[:: dp.scatter_decimate] cmap = plt.get_cmap(dp.amplitude_cmap) if dp.amplitude_clim is None: amps = amps_abs From 74a4d3d15fc0cd035fc8ffbc5d2dd684e3fc7aa4 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 12 Jul 2023 15:02:25 +0200 Subject: [PATCH 042/166] consider t_start! --- src/spikeinterface/widgets/matplotlib/motion.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/widgets/matplotlib/motion.py b/src/spikeinterface/widgets/matplotlib/motion.py index 46f0f7d3e3..7556ec8674 100644 --- a/src/spikeinterface/widgets/matplotlib/motion.py +++ b/src/spikeinterface/widgets/matplotlib/motion.py @@ -37,13 +37,16 @@ def do_plot(self, data_plot, **backend_kwargs): else: motion_lim = dp.motion_lim + temporal_bins = dp.temporal_bins if dp.times is None: times = np.arange(np.max(dp.peaks["sample_index"]) + 1) / dp.sampling_frequency else: + # use real times and adjust temporal bins with t_start times = dp.times + temporal_bins += times[0] corrected_location = correct_motion_on_peaks( - dp.peaks, dp.peak_locations, times, dp.motion, dp.temporal_bins, dp.spatial_bins, direction="y" + dp.peaks, dp.peak_locations, times, dp.motion, temporal_bins, dp.spatial_bins, direction="y" ) x = dp.peaks["sample_index"] / dp.sampling_frequency @@ -78,8 +81,6 @@ def do_plot(self, data_plot, **backend_kwargs): color_kwargs = dict(color="k", c=None, alpha=dp.amplitude_alpha) ax0.scatter(x, y, s=1, **color_kwargs) - # for i in range(dp.motion.shape[1]): - # ax0.plot(dp.temporal_bins, dp.motion[:, i] + dp.spatial_bins[i], color="C8", alpha=1.0) if dp.depth_lim is not None: ax0.set_ylim(*dp.depth_lim) ax0.set_title("Peak depth") @@ -91,8 +92,8 @@ def do_plot(self, data_plot, **backend_kwargs): ax1.set_ylabel("Depth [um]") ax1.set_title("Corrected peak depth") - ax2.plot(dp.temporal_bins, dp.motion, alpha=0.2, color="black") - ax2.plot(dp.temporal_bins, np.mean(dp.motion, axis=1), color="C0") + ax2.plot(temporal_bins, dp.motion, alpha=0.2, color="black") + ax2.plot(temporal_bins, np.mean(dp.motion, axis=1), color="C0") ax2.set_ylim(-motion_lim, motion_lim) ax2.set_ylabel("motion [um]") ax2.set_title("Motion vectors") @@ -104,8 +105,8 @@ def do_plot(self, data_plot, **backend_kwargs): aspect="auto", origin="lower", extent=( - dp.temporal_bins[0], - dp.temporal_bins[-1], + temporal_bins[0], + temporal_bins[-1], dp.spatial_bins[0], dp.spatial_bins[-1], ), From 3b1e7b3934da8fefdeee47c6967c447e79e29cb1 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 12 Jul 2023 15:04:31 +0200 Subject: [PATCH 043/166] use times for scatter too --- src/spikeinterface/widgets/matplotlib/motion.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/widgets/matplotlib/motion.py b/src/spikeinterface/widgets/matplotlib/motion.py index 7556ec8674..61ba7b2346 100644 --- a/src/spikeinterface/widgets/matplotlib/motion.py +++ b/src/spikeinterface/widgets/matplotlib/motion.py @@ -40,16 +40,17 @@ def do_plot(self, data_plot, **backend_kwargs): temporal_bins = dp.temporal_bins if dp.times is None: times = np.arange(np.max(dp.peaks["sample_index"]) + 1) / dp.sampling_frequency + x = dp.peaks["sample_index"] / dp.sampling_frequency else: # use real times and adjust temporal bins with t_start times = dp.times temporal_bins += times[0] + x = times[dp.peaks["sample_index"]] corrected_location = correct_motion_on_peaks( dp.peaks, 
dp.peak_locations, times, dp.motion, temporal_bins, dp.spatial_bins, direction="y" ) - x = dp.peaks["sample_index"] / dp.sampling_frequency y = dp.peak_locations["y"] y2 = corrected_location["y"] if dp.scatter_decimate is not None: From 7c70da9ddb1123e9b0d78b9a45c457bf80362a1f Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 12 Jul 2023 15:05:48 +0200 Subject: [PATCH 044/166] motion -> Motion in y label --- src/spikeinterface/widgets/matplotlib/motion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/widgets/matplotlib/motion.py b/src/spikeinterface/widgets/matplotlib/motion.py index 61ba7b2346..c5572b8795 100644 --- a/src/spikeinterface/widgets/matplotlib/motion.py +++ b/src/spikeinterface/widgets/matplotlib/motion.py @@ -96,7 +96,7 @@ def do_plot(self, data_plot, **backend_kwargs): ax2.plot(temporal_bins, dp.motion, alpha=0.2, color="black") ax2.plot(temporal_bins, np.mean(dp.motion, axis=1), color="C0") ax2.set_ylim(-motion_lim, motion_lim) - ax2.set_ylabel("motion [um]") + ax2.set_ylabel("Motion [um]") ax2.set_title("Motion vectors") axes = [ax0, ax1, ax2] From 5fc4fbc9f5be6fa4a77b0c24241551cf705a5ae8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20WYNGAARD?= Date: Wed, 12 Jul 2023 17:27:02 +0200 Subject: [PATCH 045/166] Add option `relative_to=True` Add thte option to set `relative_to` to `True`, in which case it will be relative to the folder it's in. --- src/spikeinterface/core/base.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 9b300e4787..4711a4344b 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -531,10 +531,15 @@ def dump_to_json(self, file_path: Union[str, Path, None] = None, relative_to=Non ---------- file_path: str Path of the json file - relative_to: str, Path, or None - If not None, file_paths are serialized relative to this path + relative_to: str, Path, True or None + If not None, file_path is serialized relative to this path + If True, file_path is serialized relative to the folder of file_path """ assert self.check_if_dumpable() + + if relative_to is True: + relative_to = Path(file_path).parent + dump_dict = self.to_dict( include_annotations=True, include_properties=False, relative_to=relative_to, folder_metadata=folder_metadata ) @@ -563,12 +568,17 @@ def dump_to_pickle( Path of the json file include_properties: bool If True, all properties are dumped - relative_to: str, Path, or None - If not None, file_paths are serialized relative to this path + relative_to: str, Path, True or None + If not None, file_path is serialized relative to this path + If True, file_path is serialized relative to the folder of file_path recursive: bool If True, all dicitionaries in the kwargs are expanded with `to_dict` as well, by default False. 
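+
+        For instance, a minimal sketch (the file path is hypothetical):
+
+        >>> # kwargs paths become relative to the pickle file's parent folder
+        >>> recording.dump_to_pickle("/data/rec.pkl", relative_to=True)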
""" assert self.check_if_dumpable() + + if relative_to is True: + relative_to = Path(file_path).parent + dump_dict = self.to_dict( include_annotations=True, include_properties=include_properties, @@ -591,6 +601,9 @@ def load(file_path: Union[str, Path], base_folder=None) -> "BaseExtractor": """ file_path = Path(file_path) + if base_folder is True: + base_folder = file_path.parent + if file_path.is_file(): # standard case based on a file (json or pickle) if str(file_path).endswith(".json"): From 3a62100ba98dde2d2e10ddb7eba1e52d58fe6e87 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20WYNGAARD?= Date: Mon, 17 Jul 2023 09:27:20 +0200 Subject: [PATCH 046/166] Update doc and tests for `relative_to=True` --- src/spikeinterface/core/base.py | 13 +++++++------ src/spikeinterface/core/tests/test_baserecording.py | 6 ++++++ 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 94cab592d3..6c21f4fdca 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -547,8 +547,9 @@ def dump(self, file_path: Union[str, Path], relative_to=None, folder_metadata=No ---------- file_path: str or Path The output file (either .json or .pkl/.pickle) - relative_to: str, Path, or None - If not None, file_paths are serialized relative to this path + relative_to: str, Path, True or None + If not None, file_path is serialized relative to this path. If True, file_path is serialized relative to the parent folder. + This means that file and folder paths in extractor objects kwargs are changed to be relative rather than absolute. """ if str(file_path).endswith(".json"): self.dump_to_json(file_path, relative_to=relative_to, folder_metadata=folder_metadata) @@ -567,8 +568,8 @@ def dump_to_json(self, file_path: Union[str, Path, None] = None, relative_to=Non file_path: str Path of the json file relative_to: str, Path, True or None - If not None, file_path is serialized relative to this path - If True, file_path is serialized relative to the folder of file_path + If not None, file_path is serialized relative to this path. If True, file_path is serialized relative to the parent folder. + This means that file and folder paths in extractor objects kwargs are changed to be relative rather than absolute. """ assert self.check_if_dumpable() @@ -604,8 +605,8 @@ def dump_to_pickle( include_properties: bool If True, all properties are dumped relative_to: str, Path, True or None - If not None, file_path is serialized relative to this path - If True, file_path is serialized relative to the folder of file_path + If not None, file_path is serialized relative to this path. If True, file_path is serialized relative to the parent folder. + This means that file and folder paths in extractor objects kwargs are changed to be relative rather than absolute. recursive: bool If True, all dicitionaries in the kwargs are expanded with `to_dict` as well, by default False. 
""" diff --git a/src/spikeinterface/core/tests/test_baserecording.py b/src/spikeinterface/core/tests/test_baserecording.py index ed9a79d055..9fbd158341 100644 --- a/src/spikeinterface/core/tests/test_baserecording.py +++ b/src/spikeinterface/core/tests/test_baserecording.py @@ -115,6 +115,12 @@ def test_BaseRecording(): rec2 = BaseExtractor.load(cache_folder / "test_BaseRecording_rel.json", base_folder=cache_folder) rec3 = load_extractor(cache_folder / "test_BaseRecording_rel.json", base_folder=cache_folder) + # dump/load relative=True + + rec.dump_to_json(cache_folder / "test_BaseRecording_rel_true.json", relative_to=True) + rec2 = BaseExtractor.load(cache_folder / "test_BaseRecording_rel_true.json", base_folder=True) + rec3 = load_extractor(cache_folder / "test_BaseRecording_rel_true.json", base_folder=True) + # cache to binary folder = cache_folder / "simple_recording" rec.save(format="binary", folder=folder) From 96f784c392f09d210004da6eecdaf167521f5812 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 17 Jul 2023 10:37:50 +0200 Subject: [PATCH 047/166] Expose AUCpslit param in KS2+ --- src/spikeinterface/sorters/external/kilosort2.py | 4 +++- src/spikeinterface/sorters/external/kilosort2_5.py | 4 +++- src/spikeinterface/sorters/external/kilosort3.py | 4 +++- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/sorters/external/kilosort2.py b/src/spikeinterface/sorters/external/kilosort2.py index 8a6998db92..267ff38e36 100644 --- a/src/spikeinterface/sorters/external/kilosort2.py +++ b/src/spikeinterface/sorters/external/kilosort2.py @@ -44,6 +44,7 @@ class Kilosort2Sorter(KilosortBase, BaseSorter): "ntbuff": 64, "nfilt_factor": 4, "NT": None, + "AUCsplit": 0.9, "wave_length": 61, "keep_good_only": False, "skip_kilosort_preprocessing": False, @@ -66,6 +67,7 @@ class Kilosort2Sorter(KilosortBase, BaseSorter): "ntbuff": "Samples of symmetrical buffer for whitening and spike detection", "nfilt_factor": "Max number of clusters per good channel (even temporary ones) 4", "NT": "Batch size (if None it is automatically computed)", + "AUCsplit": "Threshold on the area under the curve (AUC) criterion for performing a split in the final step", "wave_length": "size of the waveform extracted around each detected peak, (Default 61, maximum 81)", "keep_good_only": "If True only 'good' units are returned", "skip_kilosort_preprocessing": "Can optionaly skip the internal kilosort preprocessing", @@ -161,7 +163,7 @@ def _get_specific_options(cls, ops, params): ops["lam"] = 10.0 # splitting a cluster at the end requires at least this much isolation for each sub-cluster (max = 1) - ops["AUCsplit"] = 0.9 + ops["AUCsplit"] = params["AUCsplit"] # minimum spike rate (Hz), if a cluster falls below this for too long it gets removed ops["minFR"] = params["minFR"] diff --git a/src/spikeinterface/sorters/external/kilosort2_5.py b/src/spikeinterface/sorters/external/kilosort2_5.py index ced2bd05ab..0c9e36177e 100644 --- a/src/spikeinterface/sorters/external/kilosort2_5.py +++ b/src/spikeinterface/sorters/external/kilosort2_5.py @@ -50,6 +50,7 @@ class Kilosort2_5Sorter(KilosortBase, BaseSorter): "ntbuff": 64, "nfilt_factor": 4, "NT": None, + "AUCsplit": 0.9, "do_correction": True, "wave_length": 61, "keep_good_only": False, @@ -76,6 +77,7 @@ class Kilosort2_5Sorter(KilosortBase, BaseSorter): "nfilt_factor": "Max number of clusters per good channel (even temporary ones) 4", "do_correction": "If True drift registration is applied", "NT": "Batch size (if None it is automatically 
computed)", + "AUCsplit": "Threshold on the area under the curve (AUC) criterion for performing a split in the final step", "keep_good_only": "If True only 'good' units are returned", "wave_length": "size of the waveform extracted around each detected peak, (Default 61, maximum 81)", "skip_kilosort_preprocessing": "Can optionaly skip the internal kilosort preprocessing", @@ -182,7 +184,7 @@ def _get_specific_options(cls, ops, params): ops["lam"] = 10.0 # splitting a cluster at the end requires at least this much isolation for each sub-cluster (max = 1) - ops["AUCsplit"] = 0.9 + ops["AUCsplit"] = params["AUCsplit"] # minimum spike rate (Hz), if a cluster falls below this for too long it gets removed ops["minFR"] = params["minFR"] diff --git a/src/spikeinterface/sorters/external/kilosort3.py b/src/spikeinterface/sorters/external/kilosort3.py index c514480896..77e83e35b9 100644 --- a/src/spikeinterface/sorters/external/kilosort3.py +++ b/src/spikeinterface/sorters/external/kilosort3.py @@ -48,6 +48,7 @@ class Kilosort3Sorter(KilosortBase, BaseSorter): "nfilt_factor": 4, "do_correction": True, "NT": None, + "AUCsplit": 0.8, "wave_length": 61, "keep_good_only": False, "skip_kilosort_preprocessing": False, @@ -73,6 +74,7 @@ class Kilosort3Sorter(KilosortBase, BaseSorter): "nfilt_factor": "Max number of clusters per good channel (even temporary ones) 4", "do_correction": "If True drift registration is applied", "NT": "Batch size (if None it is automatically computed)", + "AUCsplit": "Threshold on the area under the curve (AUC) criterion for performing a split in the final step", "wave_length": "size of the waveform extracted around each detected peak, (Default 61, maximum 81)", "keep_good_only": "If True only 'good' units are returned", "skip_kilosort_preprocessing": "Can optionaly skip the internal kilosort preprocessing", @@ -171,7 +173,7 @@ def _get_specific_options(cls, ops, params): ops["lam"] = 20.0 # splitting a cluster at the end requires at least this much isolation for each sub-cluster (max = 1) - ops["AUCsplit"] = 0.8 + ops["AUCsplit"] = params["AUCsplit"] # minimum firing rate on a "good" channel (0 to skip) ops["minfr_goodchannels"] = params["minfr_goodchannels"] From 3804f3028dcbd29bcd67095525eb465d4bf5e607 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 17 Jul 2023 10:47:57 +0200 Subject: [PATCH 048/166] Improve docstrings --- src/spikeinterface/core/base.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 6c21f4fdca..9e6766c056 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -322,7 +322,7 @@ def to_dict( include_properties: bool If True, all properties are added to the dict, by default False relative_to: str, Path, or None - If not None, file_paths are serialized relative to this path, by default None + If not None, files and folders are serialized relative to this path, by default None Used in waveform extractor to maintain relative paths to binary files even if the containing folder / diretory is moved folder_metadata: str, Path, or None @@ -548,7 +548,7 @@ def dump(self, file_path: Union[str, Path], relative_to=None, folder_metadata=No file_path: str or Path The output file (either .json or .pkl/.pickle) relative_to: str, Path, True or None - If not None, file_path is serialized relative to this path. If True, file_path is serialized relative to the parent folder. + If not None, files and folders is serialized relative to this path. 
If True, the relative folder is the parent folder. This means that file and folder paths in extractor objects kwargs are changed to be relative rather than absolute. """ if str(file_path).endswith(".json"): @@ -568,7 +568,7 @@ def dump_to_json(self, file_path: Union[str, Path, None] = None, relative_to=Non file_path: str Path of the json file relative_to: str, Path, True or None - If not None, file_path is serialized relative to this path. If True, file_path is serialized relative to the parent folder. + If not None, files and folders is serialized relative to this path. If True, the relative folder is the parent folder. This means that file and folder paths in extractor objects kwargs are changed to be relative rather than absolute. """ assert self.check_if_dumpable() @@ -605,7 +605,7 @@ def dump_to_pickle( include_properties: bool If True, all properties are dumped relative_to: str, Path, True or None - If not None, file_path is serialized relative to this path. If True, file_path is serialized relative to the parent folder. + If not None, files and folders is serialized relative to this path. If True, the relative folder is the parent folder. This means that file and folder paths in extractor objects kwargs are changed to be relative rather than absolute. recursive: bool If True, all dicitionaries in the kwargs are expanded with `to_dict` as well, by default False. From a32d64e9e788d63c48505b06310759c24b75f129 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20WYNGAARD?= Date: Mon, 17 Jul 2023 11:58:46 +0200 Subject: [PATCH 049/166] Improved tests for `relative_to=True` --- src/spikeinterface/core/tests/test_baserecording.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/core/tests/test_baserecording.py b/src/spikeinterface/core/tests/test_baserecording.py index 9fbd158341..b34d09f133 100644 --- a/src/spikeinterface/core/tests/test_baserecording.py +++ b/src/spikeinterface/core/tests/test_baserecording.py @@ -2,6 +2,7 @@ test for BaseRecording are done with BinaryRecordingExtractor. but check only for BaseRecording general methods. """ +import json import shutil from pathlib import Path import pytest @@ -116,10 +117,14 @@ def test_BaseRecording(): rec3 = load_extractor(cache_folder / "test_BaseRecording_rel.json", base_folder=cache_folder) # dump/load relative=True - rec.dump_to_json(cache_folder / "test_BaseRecording_rel_true.json", relative_to=True) rec2 = BaseExtractor.load(cache_folder / "test_BaseRecording_rel_true.json", base_folder=True) rec3 = load_extractor(cache_folder / "test_BaseRecording_rel_true.json", base_folder=True) + check_recordings_equal(rec, rec2, return_scaled=False, check_annotations=True) + check_recordings_equal(rec, rec3, return_scaled=False, check_annotations=True) + with open(cache_folder / "test_BaseRecording_rel_true.json") as json_file: + data = json.load(json_file) + assert '/' not in data["kwargs"]["file_paths"][0] # Relative to parent folder, so there shouldn't be any '/' in the path. 
# cache to binary folder = cache_folder / "simple_recording" From 9105f15a8b9d0b4a9011f01ad931166d9a9db514 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 17 Jul 2023 11:59:16 +0200 Subject: [PATCH 050/166] wip --- src/spikeinterface/core/base.py | 26 +++++++++++++++--------- src/spikeinterface/sorters/basesorter.py | 2 +- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 3925f41d2b..47383267ec 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -338,6 +338,9 @@ def to_dict( kwargs = self._kwargs + if relative_to and not recursive: + raise Exception("`relative_to` is only posible when `recursive=True`") + if recursive: to_dict_kwargs = dict( include_annotations=include_annotations, @@ -560,7 +563,7 @@ def dump(self, file_path: Union[str, Path], relative_to=None, folder_metadata=No def dump_to_json(self, file_path: Union[str, Path, None] = None, relative_to=None, folder_metadata=None) -> None: """ Dump recording extractor to json file. - The extractor can be re-loaded with load_extractor_from_json(json_file) + The extractor can be re-loaded with load_extractor(json_file) Parameters ---------- @@ -568,10 +571,16 @@ def dump_to_json(self, file_path: Union[str, Path, None] = None, relative_to=Non Path of the json file relative_to: str, Path, or None If not None, file_paths are serialized relative to this path + folder_metadata: str, Path, or None + Folder with files containing additional information (e.g. probe in BaseRecording) and properties. """ - assert self.check_if_dumpable() + assert self.check_if_json_serializable(), "The extractor is not json serializable" dump_dict = self.to_dict( - include_annotations=True, include_properties=False, relative_to=relative_to, folder_metadata=folder_metadata + include_annotations=True, + include_properties=False, + relative_to=relative_to, + folder_metadata=folder_metadata, + recursive=True, ) file_path = self._get_file_path(file_path, [".json"]) @@ -584,13 +593,11 @@ def dump_to_pickle( self, file_path: Union[str, Path, None] = None, include_properties: bool = True, - relative_to=None, folder_metadata=None, - recursive: bool = False, ): """ Dump recording extractor to a pickle file. - The extractor can be re-loaded with load_extractor_from_json(json_file) + The extractor can be re-loaded with load_extractor(pickle_file) Parameters ---------- @@ -600,16 +607,15 @@ def dump_to_pickle( If True, all properties are dumped relative_to: str, Path, or None If not None, file_paths are serialized relative to this path - recursive: bool - If True, all dicitionaries in the kwargs are expanded with `to_dict` as well, by default False. + folder_metadata: str, Path, or None + Folder with files containing additional information (e.g. probe in BaseRecording) and properties. 
""" assert self.check_if_dumpable() dump_dict = self.to_dict( include_annotations=True, include_properties=include_properties, - relative_to=relative_to, folder_metadata=folder_metadata, - recursive=recursive, + recursive=False, ) file_path = self._get_file_path(file_path, [".pkl", ".pickle"]) diff --git a/src/spikeinterface/sorters/basesorter.py b/src/spikeinterface/sorters/basesorter.py index cbaba31d02..01b7b46703 100644 --- a/src/spikeinterface/sorters/basesorter.py +++ b/src/spikeinterface/sorters/basesorter.py @@ -141,7 +141,7 @@ def initialize_folder(cls, recording, output_folder, verbose, remove_existing_fo rec_file = output_folder / "spikeinterface_recording.json" if recording.check_if_json_serializable(): - recording.dump_to_json(rec_file, relative_to=output_folder) + recording.dump_to_json(rec_file, relative_to=output_folder, recursive=True) else: d = {"warning": "The recording is not rerializable to json"} rec_file.write_text(json.dumps(d, indent=4), encoding="utf8") From 4db6c3f6adbd326fb08cfa8c9d7a2e19313ece48 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 17 Jul 2023 09:59:55 +0000 Subject: [PATCH 051/166] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/tests/test_baserecording.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/core/tests/test_baserecording.py b/src/spikeinterface/core/tests/test_baserecording.py index b34d09f133..0f5a60a047 100644 --- a/src/spikeinterface/core/tests/test_baserecording.py +++ b/src/spikeinterface/core/tests/test_baserecording.py @@ -124,7 +124,9 @@ def test_BaseRecording(): check_recordings_equal(rec, rec3, return_scaled=False, check_annotations=True) with open(cache_folder / "test_BaseRecording_rel_true.json") as json_file: data = json.load(json_file) - assert '/' not in data["kwargs"]["file_paths"][0] # Relative to parent folder, so there shouldn't be any '/' in the path. + assert ( + "/" not in data["kwargs"]["file_paths"][0] + ) # Relative to parent folder, so there shouldn't be any '/' in the path. # cache to binary folder = cache_folder / "simple_recording" From 999b662588760aaa64544d47ca670ed253550b50 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 17 Jul 2023 12:04:45 +0200 Subject: [PATCH 052/166] Add docs to load and protect dict --- src/spikeinterface/core/base.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 9e6766c056..724b6bf6ac 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -548,7 +548,7 @@ def dump(self, file_path: Union[str, Path], relative_to=None, folder_metadata=No file_path: str or Path The output file (either .json or .pkl/.pickle) relative_to: str, Path, True or None - If not None, files and folders is serialized relative to this path. If True, the relative folder is the parent folder. + If not None, files and folders are serialized relative to this path. If True, the relative folder is the parent folder. This means that file and folder paths in extractor objects kwargs are changed to be relative rather than absolute. 
""" if str(file_path).endswith(".json"): @@ -568,7 +568,7 @@ def dump_to_json(self, file_path: Union[str, Path, None] = None, relative_to=Non file_path: str Path of the json file relative_to: str, Path, True or None - If not None, files and folders is serialized relative to this path. If True, the relative folder is the parent folder. + If not None, files and folders are serialized relative to this path. If True, the relative folder is the parent folder. This means that file and folder paths in extractor objects kwargs are changed to be relative rather than absolute. """ assert self.check_if_dumpable() @@ -627,13 +627,14 @@ def dump_to_pickle( file_path.write_bytes(pickle.dumps(dump_dict)) @staticmethod - def load(file_path: Union[str, Path], base_folder=None) -> "BaseExtractor": + def load(file_path: Union[str, Path], base_folder: Optional[Union[Path, str, bool]] = None) -> "BaseExtractor": """ Load extractor from file path (.json or .pkl) Used both after: * dump(...) json or pickle file * save (...) a folder which contain data + json (or pickle) + metadata. + """ file_path = Path(file_path) @@ -1044,7 +1045,11 @@ def load_extractor(file_or_folder_or_dict, base_folder=None) -> BaseExtractor: Parameters ---------- - file_or_folder_or_dict: dictionary or folder or file (json, pickle) + file_or_folder_or_dict : dictionary or folder or file (json, pickle) + The file path, folder path, or dictionary to load the extractor from + base_folder : str | Path | bool (optional) + The base folder to make relative paths absolute. + If True and file_or_folder_or_dict is a file, the parent folder of the file is used. Returns ------- @@ -1052,6 +1057,7 @@ def load_extractor(file_or_folder_or_dict, base_folder=None) -> BaseExtractor: The loaded extractor object """ if isinstance(file_or_folder_or_dict, dict): + assert not isinstance(base_folder, bool), "`base_folder` must be a string or Path when loading from dict" return BaseExtractor.from_dict(file_or_folder_or_dict, base_folder=base_folder) else: return BaseExtractor.load(file_or_folder_or_dict, base_folder=base_folder) From de7df4866251aecfb53d7b8ea0b6480901ab6ea3 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 17 Jul 2023 12:49:27 +0200 Subject: [PATCH 053/166] Do not load NP probe in OE if load_sync_channel=True --- src/spikeinterface/extractors/neoextractors/openephys.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/extractors/neoextractors/openephys.py b/src/spikeinterface/extractors/neoextractors/openephys.py index e1a6598f61..a7af1078b4 100644 --- a/src/spikeinterface/extractors/neoextractors/openephys.py +++ b/src/spikeinterface/extractors/neoextractors/openephys.py @@ -149,8 +149,8 @@ def __init__( else: exp_id = exp_ids[block_index] - # do not load probe for NIDQ stream - if "NI-DAQmx" not in stream_name: + # do not load probe for NIDQ stream or if load_sync_channel is True + if "NI-DAQmx" not in stream_name and not load_sync_channel: settings_file = self.neo_reader.folder_structure[record_node]["experiments"][exp_id]["settings_file"] if Path(settings_file).is_file(): From 7272611ae919ad86096574b34096731923268a84 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 17 Jul 2023 12:56:48 +0200 Subject: [PATCH 054/166] Oups --- src/spikeinterface/core/base.py | 12 ++---------- src/spikeinterface/core/tests/test_baserecording.py | 2 +- src/spikeinterface/core/tests/test_basesnippets.py | 2 +- src/spikeinterface/core/tests/test_basesorting.py | 2 +- 
src/spikeinterface/preprocessing/filter_gaussian.py | 2 +- 5 files changed, 6 insertions(+), 14 deletions(-) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index e58e47a15d..9b258e3876 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -976,24 +976,16 @@ def _load_extractor_from_dict(dic) -> BaseExtractor: # Create new kwargs to avoid modifying the original dict["kwargs"] new_kwargs = dict() - transform_dict_to_extractor = ( - lambda x: _load_extractor_from_dict(x, preloaded_extractor_dict) if is_dict_extractor(x) else x - ) + transform_dict_to_extractor = lambda x: _load_extractor_from_dict(x) if is_dict_extractor(x) else x for name, value in dic["kwargs"].items(): if is_dict_extractor(value): - hash_value = hash(value) - if hash_value in preloaded_extractor_dict: - new_kwargs[name] = preloaded_extractor_dict[hash_value] - else: - new_kwargs[name] = _load_extractor_from_dict(value) + new_kwargs[name] = _load_extractor_from_dict(value) elif isinstance(value, dict): new_kwargs[name] = {k: transform_dict_to_extractor(v) for k, v in value.items()} elif isinstance(value, list): new_kwargs[name] = [transform_dict_to_extractor(e) for e in value] else: new_kwargs[name] = value - hash_value = hash(value) - preloaded_extractor_dict[hash_value] = value class_name = dic["class"] extractor_class = _get_class_from_string(class_name) diff --git a/src/spikeinterface/core/tests/test_baserecording.py b/src/spikeinterface/core/tests/test_baserecording.py index 0f5a60a047..38987a58e5 100644 --- a/src/spikeinterface/core/tests/test_baserecording.py +++ b/src/spikeinterface/core/tests/test_baserecording.py @@ -107,7 +107,7 @@ def test_BaseRecording(): check_recordings_equal(rec, rec3, return_scaled=False, check_annotations=True, check_properties=True) # dump/load dict - relative - d = rec.to_dict(relative_to=cache_folder) + d = rec.to_dict(relative_to=cache_folder, recursive=True) rec2 = BaseExtractor.from_dict(d, base_folder=cache_folder) rec3 = load_extractor(d, base_folder=cache_folder) diff --git a/src/spikeinterface/core/tests/test_basesnippets.py b/src/spikeinterface/core/tests/test_basesnippets.py index d286a0dd37..d0699c892f 100644 --- a/src/spikeinterface/core/tests/test_basesnippets.py +++ b/src/spikeinterface/core/tests/test_basesnippets.py @@ -107,7 +107,7 @@ def test_BaseSnippets(): snippets3 = load_extractor(cache_folder / "test_BaseSnippets.pkl") # dump/load dict - relative - d = snippets.to_dict(relative_to=cache_folder) + d = snippets.to_dict(relative_to=cache_folder, recursive=True) snippets2 = BaseExtractor.from_dict(d, base_folder=cache_folder) snippets3 = load_extractor(d, base_folder=cache_folder) diff --git a/src/spikeinterface/core/tests/test_basesorting.py b/src/spikeinterface/core/tests/test_basesorting.py index 6e471121b6..9214f4b0e4 100644 --- a/src/spikeinterface/core/tests/test_basesorting.py +++ b/src/spikeinterface/core/tests/test_basesorting.py @@ -134,7 +134,7 @@ def test_npy_sorting(): seg_nframes = [9, 5] rec = NumpyRecording([np.zeros((nframes, 10)) for nframes in seg_nframes], sampling_frequency=sfreq) # assert_raises(Exception, sorting.register_recording, rec) - with pytest.warns(): + with pytest.warns(UserWarning): sorting.register_recording(rec) # Registering a rec with too many segments diff --git a/src/spikeinterface/preprocessing/filter_gaussian.py b/src/spikeinterface/preprocessing/filter_gaussian.py index 56d43e13e8..79b5ba5bc3 100644 --- a/src/spikeinterface/preprocessing/filter_gaussian.py +++ 
b/src/spikeinterface/preprocessing/filter_gaussian.py
@@ -40,7 +40,7 @@ def __init__(self, recording: BaseRecording, freq_min: float = 300.0, freq_max:
         for parent_segment in recording._recording_segments:
             self.add_recording_segment(GaussianFilterRecordingSegment(parent_segment, freq_min, freq_max))
 
-        self._kwargs = {"recording": recording.to_dict(), "freq_min": freq_min, "freq_max": freq_max}
+        self._kwargs = {"recording": recording, "freq_min": freq_min, "freq_max": freq_max}
 
 
 class GaussianFilterRecordingSegment(BasePreprocessorSegment):

From 2d3e4a21af1a80cb5228eaf3dd171dd882810e81 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Mon, 17 Jul 2023 13:13:27 +0200
Subject: [PATCH 055/166] Use sampling_frequency for correct_motion_on_peaks and simplify plotting

---
 doc/how_to/handle_drift.rst | 2 +-
 examples/how_to/handle_drift.py | 2 +-
 src/spikeinterface/preprocessing/motion.py | 3 ++-
 .../benchmark/benchmark_motion_estimation.py | 4 +---
 .../sortingcomponents/motion_interpolation.py | 8 ++++---
 .../tests/test_motion_interpolation.py | 3 +--
 .../widgets/matplotlib/motion.py | 23 +++++++++++--------
 src/spikeinterface/widgets/motion.py | 9 ++------
 8 files changed, 27 insertions(+), 27 deletions(-)

diff --git a/doc/how_to/handle_drift.rst b/doc/how_to/handle_drift.rst
index a4fd13097d..53b68e8c17 100644
--- a/doc/how_to/handle_drift.rst
+++ b/doc/how_to/handle_drift.rst
@@ -272,7 +272,7 @@ to display the results.
         #color='black',
         ax.scatter(loc['x'][mask][sl], loc['y'][mask][sl], **color_kargs)
 
-    loc2 = correct_motion_on_peaks(motion_info['peaks'], motion_info['peak_locations'], rec.get_times(),
+    loc2 = correct_motion_on_peaks(motion_info['peaks'], motion_info['peak_locations'], rec.sampling_frequency,
                                    motion_info['motion'], motion_info['temporal_bins'], motion_info['spatial_bins'], direction="y")
 
     ax = axs[1]
diff --git a/examples/how_to/handle_drift.py b/examples/how_to/handle_drift.py
index 7f4a4e2db4..26841f49dd 100644
--- a/examples/how_to/handle_drift.py
+++ b/examples/how_to/handle_drift.py
@@ -166,7 +166,7 @@ def preprocess_chain(rec):
     #color='black',
     ax.scatter(loc['x'][mask][sl], loc['y'][mask][sl], **color_kargs)
 
-    loc2 = correct_motion_on_peaks(motion_info['peaks'], motion_info['peak_locations'], rec.get_times(),
+    loc2 = correct_motion_on_peaks(motion_info['peaks'], motion_info['peak_locations'], rec.sampling_frequency,
                                    motion_info['motion'], motion_info['temporal_bins'], motion_info['spatial_bins'], direction="y")
 
     ax = axs[1]
diff --git a/src/spikeinterface/preprocessing/motion.py b/src/spikeinterface/preprocessing/motion.py
index 957d4f588e..00d51e7476 100644
--- a/src/spikeinterface/preprocessing/motion.py
+++ b/src/spikeinterface/preprocessing/motion.py
@@ -258,7 +258,7 @@ def correct_motion(
     noise_levels = get_noise_levels(recording, return_scaled=False)
 
     if select_kwargs is None:
-        # maybe do this directly in the folderwhen not None
+        # maybe do this directly in the folder when not None
         gather_mode = "memory"
 
         # node detect
@@ -328,6 +328,7 @@ def correct_motion(
         estimate_motion_kwargs=estimate_motion_kwargs,
         interpolate_motion_kwargs=interpolate_motion_kwargs,
         job_kwargs=job_kwargs,
+        sampling_frequency=recording.sampling_frequency,
     )
     (folder / "parameters.json").write_text(json.dumps(parameters, indent=4, cls=SIJsonEncoder), encoding="utf8")
     (folder / "run_times.json").write_text(json.dumps(run_times, indent=4), encoding="utf8")
diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py
b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py index bf3577368e..dd35670abd 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py @@ -337,12 +337,10 @@ def plot_motion_corrected_peaks(self, scaling_probe=1.5, alpha=0.05, figsize=(15 channel_positions = self.recording.get_channel_locations() probe_y_min, probe_y_max = channel_positions[:, 1].min(), channel_positions[:, 1].max() - times = self.recording.get_times() - peak_locations_corrected = correct_motion_on_peaks( self.selected_peaks, self.peak_locations, - times, + self.recording.sampling_frequency, self.motion, self.temporal_bins, self.spatial_bins, diff --git a/src/spikeinterface/sortingcomponents/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion_interpolation.py index 2fa23ee98d..59cee00994 100644 --- a/src/spikeinterface/sortingcomponents/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion_interpolation.py @@ -19,7 +19,7 @@ def correct_motion_on_peaks( peaks, peak_locations, - times, + sampling_frequency, motion, temporal_bins, spatial_bins, @@ -34,8 +34,8 @@ def correct_motion_on_peaks( peaks vector peak_locations: np.array peaks location vector - times: np.array - times vector of recording + sampling_frequency: np.array + sampling_frequency of the recording motion: np.array 2D motion.shape[0] equal temporal_bins.shape[0] motion.shape[1] equal 1 when "rigid" motion equal temporal_bins.shape[0] when "non-rigid" @@ -49,6 +49,8 @@ def correct_motion_on_peaks( corrected_peak_locations: np.array Motion-corrected peak locations """ + # make linear times + times = np.arange(np.max(peaks["sample_index"]) + 1) / sampling_frequency corrected_peak_locations = peak_locations.copy() if spatial_bins.shape[0] == 1: diff --git a/src/spikeinterface/sortingcomponents/tests/test_motion_interpolation.py b/src/spikeinterface/sortingcomponents/tests/test_motion_interpolation.py index b25cea69a6..b7ab67350e 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/tests/test_motion_interpolation.py @@ -44,12 +44,11 @@ def test_correct_motion_on_peaks(): # fake locations peak_locations = np.zeros((peaks.size), dtype=[("x", "float32"), ("y", "float")]) - times = rec.get_times() corrected_peak_locations = correct_motion_on_peaks( peaks, peak_locations, - times, + rec.sampling_frequency, motion, temporal_bins, spatial_bins, diff --git a/src/spikeinterface/widgets/matplotlib/motion.py b/src/spikeinterface/widgets/matplotlib/motion.py index c5572b8795..c4f32e4e75 100644 --- a/src/spikeinterface/widgets/matplotlib/motion.py +++ b/src/spikeinterface/widgets/matplotlib/motion.py @@ -37,18 +37,23 @@ def do_plot(self, data_plot, **backend_kwargs): else: motion_lim = dp.motion_lim - temporal_bins = dp.temporal_bins if dp.times is None: - times = np.arange(np.max(dp.peaks["sample_index"]) + 1) / dp.sampling_frequency + temporal_bins_plot = dp.temporal_bins x = dp.peaks["sample_index"] / dp.sampling_frequency else: # use real times and adjust temporal bins with t_start - times = dp.times - temporal_bins += times[0] + temporal_bins_plot = dp.temporal_bins + times[0] x = times[dp.peaks["sample_index"]] corrected_location = correct_motion_on_peaks( - dp.peaks, dp.peak_locations, times, dp.motion, temporal_bins, dp.spatial_bins, direction="y" + dp.peaks, + dp.peak_locations, + dp, + sampling_frequency, + 
dp.motion, + dp.temporal_bins, + dp.spatial_bins, + direction="y", ) y = dp.peak_locations["y"] @@ -93,8 +98,8 @@ def do_plot(self, data_plot, **backend_kwargs): ax1.set_ylabel("Depth [um]") ax1.set_title("Corrected peak depth") - ax2.plot(temporal_bins, dp.motion, alpha=0.2, color="black") - ax2.plot(temporal_bins, np.mean(dp.motion, axis=1), color="C0") + ax2.plot(temporal_bins_plot, dp.motion, alpha=0.2, color="black") + ax2.plot(temporal_bins_plot, np.mean(dp.motion, axis=1), color="C0") ax2.set_ylim(-motion_lim, motion_lim) ax2.set_ylabel("Motion [um]") ax2.set_title("Motion vectors") @@ -106,8 +111,8 @@ def do_plot(self, data_plot, **backend_kwargs): aspect="auto", origin="lower", extent=( - temporal_bins[0], - temporal_bins[-1], + temporal_bins_plot[0], + temporal_bins_plot[-1], dp.spatial_bins[0], dp.spatial_bins[-1], ), diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py index 6feb597d46..82e9be2407 100644 --- a/src/spikeinterface/widgets/motion.py +++ b/src/spikeinterface/widgets/motion.py @@ -17,7 +17,7 @@ class MotionWidget(BaseWidget): motion_info: dict The motion info return by correct_motion() or load back with load_motion_info() recording : RecordingExtractor, optional - The recording extractor object (used to get sampling frequency and times), default None) + The recording extractor object (only used to get "real" times), default None sampling_frequency : float, optional The sampling frequency (needed if recording is None), default None depth_lim : tuple @@ -42,7 +42,6 @@ def __init__( self, motion_info, recording=None, - sampling_frequency=None, depth_lim=None, motion_lim=None, color_amplitude=False, @@ -53,14 +52,10 @@ def __init__( backend=None, **backend_kwargs, ): - assert recording or sampling_frequency, "recording or sampling_frequency must be provided" - if recording is not None: - sampling_frequency = recording.sampling_frequency - times = recording.get_times() if recording is not None else None plot_data = dict( - sampling_frequency=sampling_frequency, + sampling_frequency=motion_info["parameters"]["sampling_frequency"], times=times, depth_lim=depth_lim, motion_lim=motion_lim, From a3dcdb577106c2c959dd8d9621541bfa4bbd7c73 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 17 Jul 2023 13:16:44 +0200 Subject: [PATCH 056/166] Wrong kwarg --- src/spikeinterface/comparison/multicomparisons.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/comparison/multicomparisons.py b/src/spikeinterface/comparison/multicomparisons.py index 1b29d432a9..ed9ed7520c 100644 --- a/src/spikeinterface/comparison/multicomparisons.py +++ b/src/spikeinterface/comparison/multicomparisons.py @@ -211,7 +211,7 @@ def load_from_folder(folder_path): with (folder_path / "sortings.json").open() as f: dict_sortings = json.load(f) name_list = list(dict_sortings.keys()) - sorting_list = [load_extractor(v, base_path=folder_path) for v in dict_sortings.values()] + sorting_list = [load_extractor(v, base_folder=folder_path) for v in dict_sortings.values()] mcmp = MultiSortingComparison(sorting_list=sorting_list, name_list=list(name_list), do_matching=False, **kwargs) filename = str(folder_path / "multicomparison.gpickle") with open(filename, "rb") as f: From 5c9838e2bea5604c23e27ba85cf354e7fb9d5b1e Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 17 Jul 2023 16:02:31 +0200 Subject: [PATCH 057/166] Ramon's suggestions --- src/spikeinterface/core/base.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 
deletions(-) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 9b258e3876..484b159756 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -339,7 +339,7 @@ def to_dict( kwargs = self._kwargs if relative_to and not recursive: - raise Exception("`relative_to` is only posible when `recursive=True`") + raise ValueError("`relative_to` is only posible when `recursive=True`") if recursive: to_dict_kwargs = dict( @@ -570,15 +570,19 @@ def dump_to_json(self, file_path: Union[str, Path, None] = None, relative_to=Non ---------- file_path: str Path of the json file - relative_to: str, Path, or None - If not None, file_paths are serialized relative to this path + relative_to: str, Path, True or None + If not None, files and folders are serialized relative to this path. If True, the relative folder is the parent folder. + This means that file and folder paths in extractor objects kwargs are changed to be relative rather than absolute. folder_metadata: str, Path, or None Folder with files containing additional information (e.g. probe in BaseRecording) and properties. """ assert self.check_if_json_serializable(), "The extractor is not json serializable" - if relative_to is True: - relative_to = Path(file_path).parent + # Writing paths as relative_to requires recursively expanding the dict + if relative_to: + recursive = True + # We use relative_to == True to encode using the parent_folder + relative_to = Path(file_path).parent if relative_to is True else relative_to dump_dict = self.to_dict( include_annotations=True, From a665c283557b15bdc54ced7e0de2e89c4a331d2d Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 17 Jul 2023 16:20:30 +0200 Subject: [PATCH 058/166] Make all paths resolved and absolute --- src/spikeinterface/core/base.py | 24 +++++++++++-------- src/spikeinterface/core/binaryfolder.py | 2 +- .../core/binaryrecordingextractor.py | 2 +- src/spikeinterface/core/npyfoldersnippets.py | 2 +- .../core/npysnippetsextractor.py | 2 +- src/spikeinterface/core/npzfolder.py | 2 +- .../core/zarrrecordingextractor.py | 2 +- src/spikeinterface/exporters/to_phy.py | 2 +- src/spikeinterface/extractors/cbin_ibl.py | 5 +++- .../cellexplorersortingextractor.py | 2 +- .../extractors/combinatoextractors.py | 4 ++-- .../extractors/hdsortextractors.py | 2 +- .../extractors/herdingspikesextractors.py | 2 +- .../extractors/matlabhelpers.py | 2 +- .../extractors/mdaextractors.py | 7 ++++-- .../extractors/neoextractors/alphaomega.py | 6 ++--- .../extractors/neoextractors/axona.py | 2 +- .../extractors/neoextractors/biocam.py | 10 ++++++-- .../extractors/neoextractors/blackrock.py | 6 ++--- .../extractors/neoextractors/ced.py | 4 ++-- .../extractors/neoextractors/edf.py | 6 ++--- .../extractors/neoextractors/intan.py | 4 ++-- .../extractors/neoextractors/maxwell.py | 4 ++-- .../extractors/neoextractors/mcsraw.py | 4 ++-- .../extractors/neoextractors/mearec.py | 16 +++++++++---- .../extractors/neoextractors/neuralynx.py | 6 ++--- .../extractors/neoextractors/neuroscope.py | 2 +- .../extractors/neoextractors/nix.py | 4 ++-- .../extractors/neoextractors/openephys.py | 10 ++++---- .../extractors/neoextractors/plexon.py | 8 +++---- .../extractors/neoextractors/spike2.py | 4 ++-- .../extractors/neoextractors/spikegadgets.py | 4 ++-- .../extractors/neoextractors/spikeglx.py | 6 +++-- .../extractors/neoextractors/tdt.py | 4 ++-- .../extractors/nwbextractors.py | 12 +++++----- .../extractors/shybridextractors.py | 4 ++-- 
.../extractors/spykingcircusextractors.py | 2 +- .../extractors/yassextractors.py | 2 +- src/spikeinterface/sorters/basesorter.py | 4 ++-- 39 files changed, 111 insertions(+), 85 deletions(-) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 724b6bf6ac..e6f70bf05e 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -394,13 +394,13 @@ def to_dict( dump_dict["properties"] = {k: self._properties.get(k, None) for k in self._main_properties} if relative_to is not None: - relative_to = Path(relative_to).absolute() + relative_to = Path(relative_to).resolve().absolute() assert relative_to.is_dir(), "'relative_to' must be an existing directory" dump_dict = _make_paths_relative(dump_dict, relative_to) if folder_metadata is not None: if relative_to is not None: - folder_metadata = Path(folder_metadata).absolute().relative_to(relative_to) + folder_metadata = Path(folder_metadata).resolve().absolute().relative_to(relative_to) dump_dict["folder_metadata"] = str(folder_metadata) return dump_dict @@ -533,7 +533,7 @@ def _get_file_path(file_path: Union[str, Path], extensions: Sequence) -> Path: file_path.parent.mkdir(parents=True, exist_ok=True) folder_path = file_path.parent if Path(file_path).suffix == "": - file_path = folder_path / (str(file_path) + ext) + file_path = folder_path / (str(Path(file_path).resolve().absolute()) + ext) assert file_path.suffix in extensions, "'file_path' should have one of the following extensions:" " %s" % ( ", ".join(extensions) ) @@ -551,9 +551,11 @@ def dump(self, file_path: Union[str, Path], relative_to=None, folder_metadata=No If not None, files and folders are serialized relative to this path. If True, the relative folder is the parent folder. This means that file and folder paths in extractor objects kwargs are changed to be relative rather than absolute. 
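
        Examples
        --------
        A hedged sketch of what ends up in the serialized kwargs (both paths
        below are hypothetical). Without relative_to, file paths are stored
        resolved and absolute; with relative_to (a folder, or True for this
        file's parent), they are rewritten relative to it:

        >>> recording.dump("/data/session1/rec.json")  # doctest: +SKIP
        >>> # kwargs now hold e.g. "/data/session1/traces_0.raw"
        >>> recording.dump("/data/session1/rec.json", relative_to=True)  # doctest: +SKIP
        >>> # kwargs now hold e.g. "traces_0.raw"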
""" - if str(file_path).endswith(".json"): + if str(Path(file_path).resolve().absolute()).endswith(".json"): self.dump_to_json(file_path, relative_to=relative_to, folder_metadata=folder_metadata) - elif str(file_path).endswith(".pkl") or str(file_path).endswith(".pickle"): + elif str(Path(file_path).resolve().absolute()).endswith(".pkl") or str( + Path(file_path).resolve().absolute() + ).endswith(".pickle"): self.dump_to_pickle(file_path, relative_to=relative_to, folder_metadata=folder_metadata) else: raise ValueError("Dump: file must .json or .pkl") @@ -643,11 +645,13 @@ def load(file_path: Union[str, Path], base_folder: Optional[Union[Path, str, boo if file_path.is_file(): # standard case based on a file (json or pickle) - if str(file_path).endswith(".json"): - with open(str(file_path), "r") as f: + if str(Path(file_path).resolve().absolute()).endswith(".json"): + with open(str(Path(file_path).resolve().absolute()), "r") as f: d = json.load(f) - elif str(file_path).endswith(".pkl") or str(file_path).endswith(".pickle"): - with open(str(file_path), "rb") as f: + elif str(Path(file_path).resolve().absolute()).endswith(".pkl") or str( + Path(file_path).resolve().absolute() + ).endswith(".pickle"): + with open(str(Path(file_path).resolve().absolute()), "rb") as f: d = pickle.load(f) else: raise ValueError(f"Impossible to load {file_path}") @@ -936,7 +940,7 @@ def save_to_zarr( def _make_paths_relative(d, relative) -> dict: - relative = str(Path(relative).absolute()) + relative = str(Path(relative).resolve().absolute()) func = lambda p: os.path.relpath(str(p), start=relative) return recursive_path_modifier(d, func, target="path", copy=True) diff --git a/src/spikeinterface/core/binaryfolder.py b/src/spikeinterface/core/binaryfolder.py index d9a4ce0963..1a95d4b2bb 100644 --- a/src/spikeinterface/core/binaryfolder.py +++ b/src/spikeinterface/core/binaryfolder.py @@ -47,7 +47,7 @@ def __init__(self, folder_path): folder_metadata = folder_path self.load_metadata_from_folder(folder_metadata) - self._kwargs = dict(folder_path=str(folder_path.absolute())) + self._kwargs = dict(folder_path=str(Path(folder_path).resolve().absolute())) self._bin_kwargs = d["kwargs"] if "num_channels" not in self._bin_kwargs: assert "num_chan" in self._bin_kwargs, "Cannot find num_channels or num_chan in binary.json" diff --git a/src/spikeinterface/core/binaryrecordingextractor.py b/src/spikeinterface/core/binaryrecordingextractor.py index b5c1d2c888..c41ea1e095 100644 --- a/src/spikeinterface/core/binaryrecordingextractor.py +++ b/src/spikeinterface/core/binaryrecordingextractor.py @@ -116,7 +116,7 @@ def __init__( self.set_channel_offsets(offset_to_uV) self._kwargs = { - "file_paths": [str(e.absolute()) for e in file_path_list], + "file_paths": [str(Path(e).resolve().absolute()) for e in file_path_list], "sampling_frequency": sampling_frequency, "t_starts": t_starts, "num_channels": num_channels, diff --git a/src/spikeinterface/core/npyfoldersnippets.py b/src/spikeinterface/core/npyfoldersnippets.py index b7c773aad3..04d01954fb 100644 --- a/src/spikeinterface/core/npyfoldersnippets.py +++ b/src/spikeinterface/core/npyfoldersnippets.py @@ -48,7 +48,7 @@ def __init__(self, folder_path): folder_metadata = folder_path self.load_metadata_from_folder(folder_metadata) - self._kwargs = dict(folder_path=str(folder_path.absolute())) + self._kwargs = dict(folder_path=str(Path(folder_path).resolve().absolute())) self._bin_kwargs = d["kwargs"] diff --git a/src/spikeinterface/core/npysnippetsextractor.py 
b/src/spikeinterface/core/npysnippetsextractor.py index f534592624..12592dfee8 100644 --- a/src/spikeinterface/core/npysnippetsextractor.py +++ b/src/spikeinterface/core/npysnippetsextractor.py @@ -47,7 +47,7 @@ def __init__( self.set_channel_offsets(offset_to_uV) self._kwargs = { - "file_paths": [str(f) for f in file_paths], + "file_paths": [str(Path(f).resolve().absolute()) for f in file_paths], "sampling_frequency": sampling_frequency, "channel_ids": channel_ids, "nbefore": nbefore, diff --git a/src/spikeinterface/core/npzfolder.py b/src/spikeinterface/core/npzfolder.py index 9d2eb43af6..0d79177dd2 100644 --- a/src/spikeinterface/core/npzfolder.py +++ b/src/spikeinterface/core/npzfolder.py @@ -47,7 +47,7 @@ def __init__(self, folder_path): folder_metadata = folder_path self.load_metadata_from_folder(folder_metadata) - self._kwargs = dict(folder_path=str(folder_path.absolute())) + self._kwargs = dict(folder_path=str(Path(folder_path).resolve().absolute())) self._npz_kwargs = d["kwargs"] diff --git a/src/spikeinterface/core/zarrrecordingextractor.py b/src/spikeinterface/core/zarrrecordingextractor.py index afa27da905..5197e0fcc8 100644 --- a/src/spikeinterface/core/zarrrecordingextractor.py +++ b/src/spikeinterface/core/zarrrecordingextractor.py @@ -49,7 +49,7 @@ def __init__(self, root_path: Union[Path, str], storage_options=None): root_path = Path(root_path) else: root_path_init = str(root_path) - root_path_kwarg = str(root_path.absolute()) + root_path_kwarg = str(root_path.resolve().absolute()) else: root_path_init = root_path root_path_kwarg = root_path_init diff --git a/src/spikeinterface/exporters/to_phy.py b/src/spikeinterface/exporters/to_phy.py index 58804037b0..df1a00471f 100644 --- a/src/spikeinterface/exporters/to_phy.py +++ b/src/spikeinterface/exporters/to_phy.py @@ -111,7 +111,7 @@ def export_to_phy( if len(unit_ids) == 0: raise Exception("No non-empty units in the sorting result, can't save to Phy.") - output_folder = Path(output_folder).absolute() + output_folder = Path(output_folder).resolve().absolute() if output_folder.is_dir(): if remove_if_exists: shutil.rmtree(output_folder) diff --git a/src/spikeinterface/extractors/cbin_ibl.py b/src/spikeinterface/extractors/cbin_ibl.py index dda04fbb17..926009cb1c 100644 --- a/src/spikeinterface/extractors/cbin_ibl.py +++ b/src/spikeinterface/extractors/cbin_ibl.py @@ -106,7 +106,10 @@ def __init__(self, folder_path, load_sync_channel=False): sample_shifts = get_neuropixels_sample_shifts(self.get_num_channels(), num_channels_per_adc) self.set_property("inter_sample_shift", sample_shifts) - self._kwargs = {"folder_path": str(folder_path.absolute()), "load_sync_channel": load_sync_channel} + self._kwargs = { + "folder_path": str(Path(folder_path).resolve.absolute()), + "load_sync_channel": load_sync_channel, + } class CBinIblRecordingSegment(BaseRecordingSegment): diff --git a/src/spikeinterface/extractors/cellexplorersortingextractor.py b/src/spikeinterface/extractors/cellexplorersortingextractor.py index b40b998103..b9a4c6c576 100644 --- a/src/spikeinterface/extractors/cellexplorersortingextractor.py +++ b/src/spikeinterface/extractors/cellexplorersortingextractor.py @@ -203,7 +203,7 @@ def _retrieve_sampling_frequency_from_session_info_file(self) -> float: if self.session_info_file_path is None: self.session_info_file_path = self.session_path / f"{self.session_id}.sessionInfo.mat" - self.session_info_file_path = Path(self.session_info_file_path).absolute() + self.session_info_file_path = 
Path(self.session_info_file_path).resolve().absolute() assert ( self.session_info_file_path.is_file() ), f"No {self.session_id}.sessionInfo.mat file found in the {self.session_path}!, can't inferr sampling rate, please pass the sampling rate at initialization" diff --git a/src/spikeinterface/extractors/combinatoextractors.py b/src/spikeinterface/extractors/combinatoextractors.py index 5e17fd3045..7a682bc4f1 100644 --- a/src/spikeinterface/extractors/combinatoextractors.py +++ b/src/spikeinterface/extractors/combinatoextractors.py @@ -44,7 +44,7 @@ def __init__(self, folder_path, sampling_frequency=None, user="simple", det_sign folder_path = Path(folder_path) assert folder_path.is_dir(), "Folder {} doesn't exist".format(folder_path) if sampling_frequency is None: - h5_path = str(folder_path) + ".h5" + h5_path = str(Path(folder_path).resolve().absolute()) + ".h5" if Path(h5_path).exists(): with h5py.File(h5_path, mode="r") as f: sampling_frequency = f["sr"][0] @@ -85,7 +85,7 @@ def __init__(self, folder_path, sampling_frequency=None, user="simple", det_sign self.add_sorting_segment(CombinatoSortingSegment(spiketrains)) self.set_property("unsorted", np.array([metadata[u]["group_type"] == 0 for u in range(unit_counter)])) self.set_property("artifact", np.array([metadata[u]["group_type"] == -1 for u in range(unit_counter)])) - self._kwargs = {"folder_path": str(folder_path), "user": user, "det_sign": det_sign} + self._kwargs = {"folder_path": str(Path(folder_path).resolve().absolute()), "user": user, "det_sign": det_sign} self.extra_requirements.append("h5py") diff --git a/src/spikeinterface/extractors/hdsortextractors.py b/src/spikeinterface/extractors/hdsortextractors.py index 6b904f812b..3906fb8457 100644 --- a/src/spikeinterface/extractors/hdsortextractors.py +++ b/src/spikeinterface/extractors/hdsortextractors.py @@ -108,7 +108,7 @@ def __init__(self, file_path, keep_good_only=True): self.set_property("template", np.array(templates)) self.set_property("template_frames_cut_before", np.array(templates_frames_cut_before)) - self._kwargs = {"file_path": str(file_path), "keep_good_only": keep_good_only} + self._kwargs = {"file_path": str(Path(file_path).resolve().absolute()), "keep_good_only": keep_good_only} # TODO features # ~ for uc, unit in enumerate(units): diff --git a/src/spikeinterface/extractors/herdingspikesextractors.py b/src/spikeinterface/extractors/herdingspikesextractors.py index 1fc71b1cd0..695eba9750 100644 --- a/src/spikeinterface/extractors/herdingspikesextractors.py +++ b/src/spikeinterface/extractors/herdingspikesextractors.py @@ -57,7 +57,7 @@ def __init__(self, file_path, load_unit_info=True): BaseSorting.__init__(self, sampling_frequency, unit_ids) self.add_sorting_segment(HerdingspikesSortingSegment(unit_ids, spike_times, spike_ids)) - self._kwargs = {"file_path": str(Path(file_path).absolute()), "load_unit_info": load_unit_info} + self._kwargs = {"file_path": str(Path(file_path).resolve().absolute()), "load_unit_info": load_unit_info} self.extra_requirements.append("h5py") diff --git a/src/spikeinterface/extractors/matlabhelpers.py b/src/spikeinterface/extractors/matlabhelpers.py index 46bcf2d88c..7b61ed17cf 100644 --- a/src/spikeinterface/extractors/matlabhelpers.py +++ b/src/spikeinterface/extractors/matlabhelpers.py @@ -26,7 +26,7 @@ def __init__(self, file_path): if not file_path.is_file(): raise ValueError(f"Specified file path '{file_path}' is not a file.") - self._kwargs = {"file_path": str(file_path.absolute())} + self._kwargs = {"file_path": 
str(Path(file_path).resolve().absolute())} try: # load old-style (up to 7.2) .mat file self._data = loadmat(file_path, matlab_compatible=True) diff --git a/src/spikeinterface/extractors/mdaextractors.py b/src/spikeinterface/extractors/mdaextractors.py index 68317e25be..3adf9fcb62 100644 --- a/src/spikeinterface/extractors/mdaextractors.py +++ b/src/spikeinterface/extractors/mdaextractors.py @@ -196,7 +196,7 @@ class MdaSortingExtractor(BaseSorting): name = "mda" def __init__(self, file_path, sampling_frequency): - firings = readmda(str(file_path)) + firings = readmda(str(Path(file_path).resolve().absolute())) labels = firings[2, :] unit_ids = np.unique(labels).astype(int) BaseSorting.__init__(self, unit_ids=unit_ids, sampling_frequency=sampling_frequency) @@ -204,7 +204,10 @@ def __init__(self, file_path, sampling_frequency): sorting_segment = MdaSortingSegment(firings) self.add_sorting_segment(sorting_segment) - self._kwargs = {"file_path": str(Path(file_path).absolute()), "sampling_frequency": sampling_frequency} + self._kwargs = { + "file_path": str(Path(file_path).resolve().absolute()), + "sampling_frequency": sampling_frequency, + } @staticmethod def write_sorting(sorting, save_path, write_primary_channels=False): diff --git a/src/spikeinterface/extractors/neoextractors/alphaomega.py b/src/spikeinterface/extractors/neoextractors/alphaomega.py index 78844a5267..e546b9d971 100644 --- a/src/spikeinterface/extractors/neoextractors/alphaomega.py +++ b/src/spikeinterface/extractors/neoextractors/alphaomega.py @@ -32,12 +32,12 @@ def __init__(self, folder_path, lsx_files=None, stream_id="RAW", stream_name=Non NeoBaseRecordingExtractor.__init__( self, stream_id=stream_id, stream_name=stream_name, all_annotations=all_annotations, **neo_kwargs ) - self._kwargs.update(dict(folder_path=str(folder_path), lsx_files=lsx_files)) + self._kwargs.update(dict(folder_path=str(Path(folder_path).resolve().absolute()), lsx_files=lsx_files)) @classmethod def map_to_neo_kwargs(cls, folder_path, lsx_files=None): neo_kwargs = { - "dirname": str(folder_path), + "dirname": str(Path(folder_path).resolve().absolute()), "lsx_files": lsx_files, } return neo_kwargs @@ -58,7 +58,7 @@ def __init__(self, folder_path): @classmethod def map_to_neo_kwargs(cls, folder_path): - neo_kwargs = {"dirname": str(folder_path)} + neo_kwargs = {"dirname": str(Path(folder_path).resolve().absolute())} return neo_kwargs diff --git a/src/spikeinterface/extractors/neoextractors/axona.py b/src/spikeinterface/extractors/neoextractors/axona.py index cb0cf19ff8..f8fb4a6733 100644 --- a/src/spikeinterface/extractors/neoextractors/axona.py +++ b/src/spikeinterface/extractors/neoextractors/axona.py @@ -28,7 +28,7 @@ def __init__(self, file_path, all_annotations=False): @classmethod def map_to_neo_kwargs(cls, file_path): - neo_kwargs = {"filename": str(file_path)} + neo_kwargs = {"filename": str(Path(file_path).resolve().absolute())} return neo_kwargs diff --git a/src/spikeinterface/extractors/neoextractors/biocam.py b/src/spikeinterface/extractors/neoextractors/biocam.py index 23fcb2c419..1886317ffc 100644 --- a/src/spikeinterface/extractors/neoextractors/biocam.py +++ b/src/spikeinterface/extractors/neoextractors/biocam.py @@ -57,11 +57,17 @@ def __init__( self.set_property("row", self.get_property("contact_vector")["row"]) self.set_property("col", self.get_property("contact_vector")["col"]) - self._kwargs.update({"file_path": str(file_path), "mea_pitch": mea_pitch, "electrode_width": electrode_width}) + self._kwargs.update( + { + "file_path": 
str(Path(file_path).resolve().absolute()), + "mea_pitch": mea_pitch, + "electrode_width": electrode_width, + } + ) @classmethod def map_to_neo_kwargs(cls, file_path): - neo_kwargs = {"filename": str(file_path)} + neo_kwargs = {"filename": str(Path(file_path).resolve().absolute())} return neo_kwargs diff --git a/src/spikeinterface/extractors/neoextractors/blackrock.py b/src/spikeinterface/extractors/neoextractors/blackrock.py index 5408173a12..98ddf4204b 100644 --- a/src/spikeinterface/extractors/neoextractors/blackrock.py +++ b/src/spikeinterface/extractors/neoextractors/blackrock.py @@ -56,11 +56,11 @@ def __init__( use_names_as_ids=use_names_as_ids, **neo_kwargs, ) - self._kwargs.update({"file_path": str(file_path)}) + self._kwargs.update({"file_path": str(Path(file_path).resolve().absolute())}) @classmethod def map_to_neo_kwargs(cls, file_path): - neo_kwargs = {"filename": str(file_path)} + neo_kwargs = {"filename": str(Path(file_path).resolve().absolute())} return neo_kwargs @@ -115,7 +115,7 @@ def __init__( @classmethod def map_to_neo_kwargs(cls, file_path): - neo_kwargs = {"filename": str(file_path)} + neo_kwargs = {"filename": str(Path(file_path).resolve().absolute())} return neo_kwargs diff --git a/src/spikeinterface/extractors/neoextractors/ced.py b/src/spikeinterface/extractors/neoextractors/ced.py index 86865f312b..c5d83d25e7 100644 --- a/src/spikeinterface/extractors/neoextractors/ced.py +++ b/src/spikeinterface/extractors/neoextractors/ced.py @@ -34,12 +34,12 @@ def __init__(self, file_path, stream_id=None, stream_name=None, all_annotations= NeoBaseRecordingExtractor.__init__( self, stream_id=stream_id, stream_name=stream_name, all_annotations=all_annotations, **neo_kwargs ) - self._kwargs.update(dict(file_path=str(file_path))) + self._kwargs.update(dict(file_path=str(Path(file_path).resolve().absolute()))) self.extra_requirements.append("neo[ced]") @classmethod def map_to_neo_kwargs(cls, file_path): - neo_kwargs = {"filename": str(file_path)} + neo_kwargs = {"filename": str(Path(file_path).resolve().absolute())} return neo_kwargs diff --git a/src/spikeinterface/extractors/neoextractors/edf.py b/src/spikeinterface/extractors/neoextractors/edf.py index b50df7868c..a27423e523 100644 --- a/src/spikeinterface/extractors/neoextractors/edf.py +++ b/src/spikeinterface/extractors/neoextractors/edf.py @@ -27,16 +27,16 @@ class EDFRecordingExtractor(NeoBaseRecordingExtractor): name = "edf" def __init__(self, file_path, stream_id=None, stream_name=None, all_annotations=False): - neo_kwargs = {"filename": str(file_path)} + neo_kwargs = {"filename": str(Path(file_path).resolve().absolute())} NeoBaseRecordingExtractor.__init__( self, stream_id=stream_id, stream_name=stream_name, all_annotations=all_annotations, **neo_kwargs ) - self._kwargs.update({"file_path": str(file_path)}) + self._kwargs.update({"file_path": str(Path(file_path).resolve().absolute())}) self.extra_requirements.append("neo[edf]") @classmethod def map_to_neo_kwargs(cls, file_path): - neo_kwargs = {"filename": str(file_path)} + neo_kwargs = {"filename": str(Path(file_path).resolve().absolute())} return neo_kwargs diff --git a/src/spikeinterface/extractors/neoextractors/intan.py b/src/spikeinterface/extractors/neoextractors/intan.py index 1cbd5bd869..a456052668 100644 --- a/src/spikeinterface/extractors/neoextractors/intan.py +++ b/src/spikeinterface/extractors/neoextractors/intan.py @@ -36,11 +36,11 @@ def __init__(self, file_path, stream_id=None, stream_name=None, all_annotations= **neo_kwargs, ) - 
self._kwargs.update(dict(file_path=str(file_path))) + self._kwargs.update(dict(file_path=str(Path(file_path).resolve().absolute()))) @classmethod def map_to_neo_kwargs(cls, file_path): - neo_kwargs = {"filename": str(file_path)} + neo_kwargs = {"filename": str(Path(file_path).resolve().absolute())} return neo_kwargs diff --git a/src/spikeinterface/extractors/neoextractors/maxwell.py b/src/spikeinterface/extractors/neoextractors/maxwell.py index 7010b55721..0a65e2e20a 100644 --- a/src/spikeinterface/extractors/neoextractors/maxwell.py +++ b/src/spikeinterface/extractors/neoextractors/maxwell.py @@ -70,11 +70,11 @@ def __init__( probe = pi.read_maxwell(file_path, well_name=well_name, rec_name=rec_name) self.set_probe(probe, in_place=True) self.set_property("electrode", self.get_property("contact_vector")["electrode"]) - self._kwargs.update(dict(file_path=str(file_path), rec_name=rec_name)) + self._kwargs.update(dict(file_path=str(Path(file_path).resolve().absolute()), rec_name=rec_name)) @classmethod def map_to_neo_kwargs(cls, file_path, rec_name=None): - neo_kwargs = {"filename": str(file_path), "rec_name": rec_name} + neo_kwargs = {"filename": str(Path(file_path).resolve().absolute()), "rec_name": rec_name} return neo_kwargs def install_maxwell_plugin(self, force_download=False): diff --git a/src/spikeinterface/extractors/neoextractors/mcsraw.py b/src/spikeinterface/extractors/neoextractors/mcsraw.py index 6e377ea799..5a8ca1598d 100644 --- a/src/spikeinterface/extractors/neoextractors/mcsraw.py +++ b/src/spikeinterface/extractors/neoextractors/mcsraw.py @@ -40,11 +40,11 @@ def __init__(self, file_path, stream_id=None, stream_name=None, block_index=None all_annotations=all_annotations, **neo_kwargs, ) - self._kwargs.update(dict(file_path=str(file_path))) + self._kwargs.update(dict(file_path=str(Path(file_path).resolve().absolute()))) @classmethod def map_to_neo_kwargs(cls, file_path): - neo_kwargs = {"filename": str(file_path)} + neo_kwargs = {"filename": str(Path(file_path).resolve().absolute())} return neo_kwargs diff --git a/src/spikeinterface/extractors/neoextractors/mearec.py b/src/spikeinterface/extractors/neoextractors/mearec.py index ab24034b9a..82e2f2b858 100644 --- a/src/spikeinterface/extractors/neoextractors/mearec.py +++ b/src/spikeinterface/extractors/neoextractors/mearec.py @@ -40,14 +40,18 @@ def __init__(self, file_path: Union[str, Path], all_annotations: bool = False): if hasattr(self.neo_reader._recgen, "gain_to_uV"): self.set_channel_gains(self.neo_reader._recgen.gain_to_uV) - self._kwargs.update({"file_path": str(file_path)}) + self._kwargs.update({"file_path": str(Path(file_path).resolve().absolute())}) @classmethod def map_to_neo_kwargs( cls, file_path, ): - neo_kwargs = {"filename": str(file_path), "load_spiketrains": False, "load_analogsignal": True} + neo_kwargs = { + "filename": str(Path(file_path).resolve().absolute()), + "load_spiketrains": False, + "load_analogsignal": True, + } return neo_kwargs @@ -63,11 +67,15 @@ def __init__(self, file_path: Union[str, Path]): sampling_frequency = self.read_sampling_frequency(file_path=file_path) NeoBaseSortingExtractor.__init__(self, sampling_frequency=sampling_frequency, use_format_ids=True, **neo_kwargs) - self._kwargs = {"file_path": str(file_path)} + self._kwargs = {"file_path": str(Path(file_path).resolve().absolute())} @classmethod def map_to_neo_kwargs(cls, file_path): - neo_kwargs = {"filename": str(file_path), "load_spiketrains": True, "load_analogsignal": False} + neo_kwargs = { + "filename": 
str(Path(file_path).resolve().absolute()), + "load_spiketrains": True, + "load_analogsignal": False, + } return neo_kwargs def read_sampling_frequency(self, file_path: Union[str, Path]) -> float: diff --git a/src/spikeinterface/extractors/neoextractors/neuralynx.py b/src/spikeinterface/extractors/neoextractors/neuralynx.py index 58e97a69ef..b9e98692a6 100644 --- a/src/spikeinterface/extractors/neoextractors/neuralynx.py +++ b/src/spikeinterface/extractors/neoextractors/neuralynx.py @@ -32,11 +32,11 @@ def __init__(self, folder_path, stream_id=None, stream_name=None, all_annotation NeoBaseRecordingExtractor.__init__( self, stream_id=stream_id, stream_name=stream_name, all_annotations=all_annotations, **neo_kwargs ) - self._kwargs.update(dict(folder_path=str(folder_path))) + self._kwargs.update(dict(folder_path=str(Path(folder_path).resolve().absolute()))) @classmethod def map_to_neo_kwargs(cls, folder_path): - neo_kwargs = {"dirname": str(folder_path)} + neo_kwargs = {"dirname": str(Path(folder_path).resolve().absolute())} return neo_kwargs @@ -90,7 +90,7 @@ def __init__( @classmethod def map_to_neo_kwargs(cls, folder_path): - neo_kwargs = {"dirname": str(folder_path)} + neo_kwargs = {"dirname": str(Path(folder_path).resolve().absolute())} return neo_kwargs diff --git a/src/spikeinterface/extractors/neoextractors/neuroscope.py b/src/spikeinterface/extractors/neoextractors/neuroscope.py index 49c194ce92..fbea66d62c 100644 --- a/src/spikeinterface/extractors/neoextractors/neuroscope.py +++ b/src/spikeinterface/extractors/neoextractors/neuroscope.py @@ -51,7 +51,7 @@ def __init__(self, file_path, xml_file_path=None, stream_id=None, stream_name=No NeoBaseRecordingExtractor.__init__( self, stream_id=stream_id, stream_name=stream_name, all_annotations=all_annotations, **neo_kwargs ) - self._kwargs.update(dict(file_path=str(file_path), xml_file_path=xml_file_path)) + self._kwargs.update(dict(file_path=str(Path(file_path).resolve().absolute()), xml_file_path=xml_file_path)) @classmethod def map_to_neo_kwargs(cls, file_path, xml_file_path=None): diff --git a/src/spikeinterface/extractors/neoextractors/nix.py b/src/spikeinterface/extractors/neoextractors/nix.py index 8781a4df71..9da7323421 100644 --- a/src/spikeinterface/extractors/neoextractors/nix.py +++ b/src/spikeinterface/extractors/neoextractors/nix.py @@ -37,12 +37,12 @@ def __init__(self, file_path, stream_id=None, stream_name=None, block_index=None all_annotations=all_annotations, **neo_kwargs, ) - self._kwargs.update(dict(file_path=str(file_path), stream_id=stream_id)) + self._kwargs.update(dict(file_path=str(Path(file_path).resolve().absolute()), stream_id=stream_id)) self.extra_requirements.append("neo[nixio]") @classmethod def map_to_neo_kwargs(cls, file_path): - neo_kwargs = {"filename": str(file_path)} + neo_kwargs = {"filename": str(Path(file_path).resolve().absolute())} return neo_kwargs diff --git a/src/spikeinterface/extractors/neoextractors/openephys.py b/src/spikeinterface/extractors/neoextractors/openephys.py index e1a6598f61..ea0fbfad9b 100644 --- a/src/spikeinterface/extractors/neoextractors/openephys.py +++ b/src/spikeinterface/extractors/neoextractors/openephys.py @@ -61,11 +61,11 @@ def __init__(self, folder_path, stream_id=None, stream_name=None, block_index=No all_annotations=all_annotations, **neo_kwargs, ) - self._kwargs.update(dict(folder_path=str(folder_path))) + self._kwargs.update(dict(folder_path=str(Path(folder_path).resolve().absolute()))) @classmethod def map_to_neo_kwargs(cls, folder_path): - neo_kwargs = 
{"dirname": str(folder_path)} + neo_kwargs = {"dirname": str(Path(folder_path).resolve().absolute())} return neo_kwargs @@ -204,7 +204,7 @@ def __init__( self._kwargs.update( dict( - folder_path=str(folder_path), + folder_path=str(Path(folder_path).resolve().absolute()), load_sync_channel=load_sync_channel, load_sync_timestamps=load_sync_timestamps, experiment_names=experiment_names, @@ -214,7 +214,7 @@ def __init__( @classmethod def map_to_neo_kwargs(cls, folder_path, load_sync_channel=False, experiment_names=None): neo_kwargs = { - "dirname": str(folder_path), + "dirname": str(Path(folder_path).resolve().absolute()), "load_sync_channel": load_sync_channel, "experiment_names": experiment_names, } @@ -248,7 +248,7 @@ def __init__(self, folder_path, block_index=None): @classmethod def map_to_neo_kwargs(cls, folder_path): - neo_kwargs = {"dirname": str(folder_path)} + neo_kwargs = {"dirname": str(Path(folder_path).resolve().absolute())} return neo_kwargs diff --git a/src/spikeinterface/extractors/neoextractors/plexon.py b/src/spikeinterface/extractors/neoextractors/plexon.py index 84c06e6974..9fee26ef9c 100644 --- a/src/spikeinterface/extractors/neoextractors/plexon.py +++ b/src/spikeinterface/extractors/neoextractors/plexon.py @@ -30,11 +30,11 @@ def __init__(self, file_path, stream_id=None, stream_name=None, all_annotations= NeoBaseRecordingExtractor.__init__( self, stream_id=stream_id, stream_name=stream_name, all_annotations=all_annotations, **neo_kwargs ) - self._kwargs.update({"file_path": str(file_path)}) + self._kwargs.update({"file_path": str(Path(file_path).resolve().absolute())}) @classmethod def map_to_neo_kwargs(cls, file_path): - neo_kwargs = {"filename": str(file_path)} + neo_kwargs = {"filename": str(Path(file_path).resolve().absolute())} return neo_kwargs @@ -61,11 +61,11 @@ def __init__(self, file_path): self.neo_reader.parse_header() sampling_frequency = self.neo_reader._global_ssampling_rate NeoBaseSortingExtractor.__init__(self, sampling_frequency=sampling_frequency, **neo_kwargs) - self._kwargs = {"file_path": str(file_path)} + self._kwargs = {"file_path": str(Path(file_path).resolve().absolute())} @classmethod def map_to_neo_kwargs(cls, file_path): - neo_kwargs = {"filename": str(file_path)} + neo_kwargs = {"filename": str(Path(file_path).resolve().absolute())} return neo_kwargs diff --git a/src/spikeinterface/extractors/neoextractors/spike2.py b/src/spikeinterface/extractors/neoextractors/spike2.py index 7fab8e4087..aaa6c198f8 100644 --- a/src/spikeinterface/extractors/neoextractors/spike2.py +++ b/src/spikeinterface/extractors/neoextractors/spike2.py @@ -31,12 +31,12 @@ def __init__(self, file_path, stream_id=None, stream_name=None, all_annotations= NeoBaseRecordingExtractor.__init__( self, stream_id=stream_id, stream_name=stream_name, all_annotations=all_annotations, **neo_kwargs ) - self._kwargs.update({"file_path": str(file_path)}) + self._kwargs.update({"file_path": str(Path(file_path).resolve().absolute())}) self.extra_requirements.append("sonpy") @classmethod def map_to_neo_kwargs(cls, file_path): - neo_kwargs = {"filename": str(file_path)} + neo_kwargs = {"filename": str(Path(file_path).resolve().absolute())} return neo_kwargs diff --git a/src/spikeinterface/extractors/neoextractors/spikegadgets.py b/src/spikeinterface/extractors/neoextractors/spikegadgets.py index e5841e9df8..02d0c92511 100644 --- a/src/spikeinterface/extractors/neoextractors/spikegadgets.py +++ b/src/spikeinterface/extractors/neoextractors/spikegadgets.py @@ -30,11 +30,11 @@ def 
__init__(self, file_path, stream_id=None, stream_name=None, block_index=None NeoBaseRecordingExtractor.__init__( self, stream_id=stream_id, stream_name=stream_name, all_annotations=all_annotations, **neo_kwargs ) - self._kwargs.update(dict(file_path=str(file_path), stream_id=stream_id)) + self._kwargs.update(dict(file_path=str(Path(file_path).resolve().absolute()), stream_id=stream_id)) @classmethod def map_to_neo_kwargs(cls, file_path): - neo_kwargs = {"filename": str(file_path)} + neo_kwargs = {"filename": str(Path(file_path).resolve().absolute())} return neo_kwargs diff --git a/src/spikeinterface/extractors/neoextractors/spikeglx.py b/src/spikeinterface/extractors/neoextractors/spikeglx.py index 1998995cb4..a9a0dd0000 100644 --- a/src/spikeinterface/extractors/neoextractors/spikeglx.py +++ b/src/spikeinterface/extractors/neoextractors/spikeglx.py @@ -90,11 +90,13 @@ def __init__(self, folder_path, load_sync_channel=False, stream_id=None, stream_ self.set_property("inter_sample_shift", sample_shifts) - self._kwargs.update(dict(folder_path=str(folder_path), load_sync_channel=load_sync_channel)) + self._kwargs.update( + dict(folder_path=str(Path(folder_path).resolve().absolute()), load_sync_channel=load_sync_channel) + ) @classmethod def map_to_neo_kwargs(cls, folder_path, load_sync_channel=False): - neo_kwargs = {"dirname": str(folder_path), "load_sync_channel": load_sync_channel} + neo_kwargs = {"dirname": str(Path(folder_path).resolve().absolute()), "load_sync_channel": load_sync_channel} return neo_kwargs diff --git a/src/spikeinterface/extractors/neoextractors/tdt.py b/src/spikeinterface/extractors/neoextractors/tdt.py index bd4cbe2339..41cf190e15 100644 --- a/src/spikeinterface/extractors/neoextractors/tdt.py +++ b/src/spikeinterface/extractors/neoextractors/tdt.py @@ -35,11 +35,11 @@ def __init__(self, folder_path, stream_id=None, stream_name=None, block_index=No all_annotations=all_annotations, **neo_kwargs, ) - self._kwargs.update(dict(folder_path=str(folder_path))) + self._kwargs.update(dict(folder_path=str(Path(folder_path).resolve().absolute()))) @classmethod def map_to_neo_kwargs(cls, folder_path): - neo_kwargs = {"dirname": str(folder_path)} + neo_kwargs = {"dirname": str(Path(folder_path).resolve().absolute())} return neo_kwargs diff --git a/src/spikeinterface/extractors/nwbextractors.py b/src/spikeinterface/extractors/nwbextractors.py index e5ac7e18bc..d9bafdcb67 100644 --- a/src/spikeinterface/extractors/nwbextractors.py +++ b/src/spikeinterface/extractors/nwbextractors.py @@ -104,7 +104,7 @@ def read_nwbfile( -------- >>> nwbfile = read_nwbfile("data.nwb", stream_mode="ros3") """ - file_path = str(file_path) + file_path = str(Path(file_path).resolve().absolute()) from pynwb import NWBHDF5IO, NWBFile if stream_mode == "fsspec": @@ -359,7 +359,7 @@ def __init__( self.set_property(property_name, values) if stream_mode not in ["fsspec", "ros3"]: - file_path = str(Path(file_path).absolute()) + file_path = str(Path(file_path).resolve().absolute()) if stream_mode == "fsspec": # only add stream_cache_path to kwargs if it was passed as an argument if stream_cache_path is not None: @@ -477,16 +477,16 @@ def __init__( fs=fsspec.filesystem("http"), cache_storage=self.stream_cache_path, ) - self._file_path = self.cfs.open(str(file_path), "rb") + self._file_path = self.cfs.open(str(Path(file_path).resolve().absolute()), "rb") file = h5py.File(self._file_path) self.io = NWBHDF5IO(file=file, mode="r", load_namespaces=True) elif stream_mode == "ros3": - self._file_path = str(file_path) + 
self._file_path = str(Path(file_path).resolve().absolute()) self.io = NWBHDF5IO(self._file_path, mode="r", load_namespaces=True, driver="ros3") else: - self._file_path = str(file_path) + self._file_path = str(Path(file_path).resolve().absolute()) self.io = NWBHDF5IO(self._file_path, mode="r", load_namespaces=True) self._nwbfile = self.io.read() @@ -537,7 +537,7 @@ def __init__( self.set_property(prop_name, np.array(values)) if stream_mode not in ["fsspec", "ros3"]: - file_path = str(Path(file_path).absolute()) + file_path = str(Path(file_path).resolve().absolute()) if stream_mode == "fsspec": stream_cache_path = str(Path(self.stream_cache_path).absolute()) self._kwargs = { diff --git a/src/spikeinterface/extractors/shybridextractors.py b/src/spikeinterface/extractors/shybridextractors.py index 130c0ce47e..10176a7502 100644 --- a/src/spikeinterface/extractors/shybridextractors.py +++ b/src/spikeinterface/extractors/shybridextractors.py @@ -169,7 +169,7 @@ def __init__(self, file_path, sampling_frequency, delimiter=","): if Path(file_path).is_file(): spike_clusters = sbio.SpikeClusters() - spike_clusters.fromCSV(str(file_path), None, delimiter=delimiter) + spike_clusters.fromCSV(str(Path(file_path).resolve().absolute()), None, delimiter=delimiter) else: raise FileNotFoundError(f"The ground truth file {file_path} could not be found") @@ -179,7 +179,7 @@ def __init__(self, file_path, sampling_frequency, delimiter=","): self.add_sorting_segment(sorting_segment) self._kwargs = { - "file_path": str(Path(file_path).absolute()), + "file_path": str(Path(file_path).resolve().absolute()), "sampling_frequency": sampling_frequency, "delimiter": delimiter, } diff --git a/src/spikeinterface/extractors/spykingcircusextractors.py b/src/spikeinterface/extractors/spykingcircusextractors.py index 11bf91b93b..5fab383759 100644 --- a/src/spikeinterface/extractors/spykingcircusextractors.py +++ b/src/spikeinterface/extractors/spykingcircusextractors.py @@ -85,7 +85,7 @@ def __init__(self, folder_path): BaseSorting.__init__(self, sample_rate, unit_ids) self.add_sorting_segment(SpykingcircustSortingSegment(unit_ids, spiketrains)) - self._kwargs = {"folder_path": str(Path(folder_path).absolute())} + self._kwargs = {"folder_path": str(Path(folder_path).resolve().absolute())} self.extra_requirements.append("h5py") diff --git a/src/spikeinterface/extractors/yassextractors.py b/src/spikeinterface/extractors/yassextractors.py index 1fb4ad9555..2fbaa72226 100644 --- a/src/spikeinterface/extractors/yassextractors.py +++ b/src/spikeinterface/extractors/yassextractors.py @@ -56,7 +56,7 @@ def __init__(self, folder_path): BaseSorting.__init__(self, sampling_frequency, unit_ids) self.add_sorting_segment(YassSortingSegment(spiketrains)) - self._kwargs = {"folder_path": str(folder_path)} + self._kwargs = {"folder_path": str(Path(folder_path).resolve().absolute())} self.extra_requirements.append("pyyaml") diff --git a/src/spikeinterface/sorters/basesorter.py b/src/spikeinterface/sorters/basesorter.py index cbaba31d02..138cceaeb6 100644 --- a/src/spikeinterface/sorters/basesorter.py +++ b/src/spikeinterface/sorters/basesorter.py @@ -120,8 +120,8 @@ def initialize_folder(cls, recording, output_folder, verbose, remove_existing_fo if output_folder is None: output_folder = cls.sorter_name + "_output" - #  .absolute() not anymore - output_folder = Path(output_folder) + # Resolve path + output_folder = Path(output_folder).resolve().absolute() sorter_output_folder = output_folder / "sorter_output" if output_folder.is_dir(): From 
82852e5a7b28be6c086cb79f9cbd385cbca2370f Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 17 Jul 2023 16:21:27 +0200 Subject: [PATCH 059/166] Add zarr --- src/spikeinterface/core/zarrrecordingextractor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/zarrrecordingextractor.py b/src/spikeinterface/core/zarrrecordingextractor.py index 5197e0fcc8..44e8674a73 100644 --- a/src/spikeinterface/core/zarrrecordingextractor.py +++ b/src/spikeinterface/core/zarrrecordingextractor.py @@ -49,7 +49,7 @@ def __init__(self, root_path: Union[Path, str], storage_options=None): root_path = Path(root_path) else: root_path_init = str(root_path) - root_path_kwarg = str(root_path.resolve().absolute()) + root_path_kwarg = str(Path(root_path).resolve().absolute()) else: root_path_init = root_path root_path_kwarg = root_path_init From 1d546c15ccfdc0c32e38b2681723b241ff4cbe89 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 17 Jul 2023 16:24:59 +0200 Subject: [PATCH 060/166] Always recursive=True no matter relative_to for dump_to_json --- src/spikeinterface/core/base.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 484b159756..537e1ee19d 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -580,9 +580,8 @@ def dump_to_json(self, file_path: Union[str, Path, None] = None, relative_to=Non # Writing paths as relative_to requires recursively expanding the dict if relative_to: - recursive = True - # We use relative_to == True to encode using the parent_folder relative_to = Path(file_path).parent if relative_to is True else relative_to + relative_to = relative_to.resolve().absolute() dump_dict = self.to_dict( include_annotations=True, From e02ce9d1145014078f08f95e559ac8270c03b8cf Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 17 Jul 2023 16:27:24 +0200 Subject: [PATCH 061/166] Ensure relative_to is a Path --- src/spikeinterface/core/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 537e1ee19d..c9892c2d42 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -580,7 +580,7 @@ def dump_to_json(self, file_path: Union[str, Path, None] = None, relative_to=Non # Writing paths as relative_to requires recursively expanding the dict if relative_to: - relative_to = Path(file_path).parent if relative_to is True else relative_to + relative_to = Path(file_path).parent if relative_to is True else Path(relative_to) relative_to = relative_to.resolve().absolute() dump_dict = self.to_dict( From d6b205a5806ad2396172430da76c3e34781f755c Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 17 Jul 2023 16:39:13 +0200 Subject: [PATCH 062/166] Correct mistake --- src/spikeinterface/core/base.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index e6f70bf05e..897b25a71f 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -533,7 +533,7 @@ def _get_file_path(file_path: Union[str, Path], extensions: Sequence) -> Path: file_path.parent.mkdir(parents=True, exist_ok=True) folder_path = file_path.parent if Path(file_path).suffix == "": - file_path = folder_path / (str(Path(file_path).resolve().absolute()) + ext) + file_path = folder_path / (str(file_path) + ext) assert file_path.suffix in 
extensions, "'file_path' should have one of the following extensions:" " %s" % ( ", ".join(extensions) ) @@ -551,11 +551,9 @@ def dump(self, file_path: Union[str, Path], relative_to=None, folder_metadata=No If not None, files and folders are serialized relative to this path. If True, the relative folder is the parent folder. This means that file and folder paths in extractor objects kwargs are changed to be relative rather than absolute. """ - if str(Path(file_path).resolve().absolute()).endswith(".json"): + if str(file_path).endswith(".json"): self.dump_to_json(file_path, relative_to=relative_to, folder_metadata=folder_metadata) - elif str(Path(file_path).resolve().absolute()).endswith(".pkl") or str( - Path(file_path).resolve().absolute() - ).endswith(".pickle"): + elif str(file_path).endswith(".pkl") or str(file_path).endswith(".pickle"): self.dump_to_pickle(file_path, relative_to=relative_to, folder_metadata=folder_metadata) else: raise ValueError("Dump: file must .json or .pkl") @@ -645,12 +643,10 @@ def load(file_path: Union[str, Path], base_folder: Optional[Union[Path, str, boo if file_path.is_file(): # standard case based on a file (json or pickle) - if str(Path(file_path).resolve().absolute()).endswith(".json"): + if str(file_path).endswith(".json"): with open(str(Path(file_path).resolve().absolute()), "r") as f: d = json.load(f) - elif str(Path(file_path).resolve().absolute()).endswith(".pkl") or str( - Path(file_path).resolve().absolute() - ).endswith(".pickle"): + elif str(file_path).endswith(".pkl") or str(file_path).endswith(".pickle"): with open(str(Path(file_path).resolve().absolute()), "rb") as f: d = pickle.load(f) else: From 59e3f87f58a10cc10f7eacf44fa13511cc8171b5 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 17 Jul 2023 17:01:46 +0200 Subject: [PATCH 063/166] Import pathlib.Path in neo extractors --- src/spikeinterface/extractors/neoextractors/alphaomega.py | 2 ++ src/spikeinterface/extractors/neoextractors/axona.py | 2 ++ src/spikeinterface/extractors/neoextractors/biocam.py | 2 ++ src/spikeinterface/extractors/neoextractors/ced.py | 2 ++ src/spikeinterface/extractors/neoextractors/edf.py | 2 ++ src/spikeinterface/extractors/neoextractors/intan.py | 2 ++ src/spikeinterface/extractors/neoextractors/maxwell.py | 1 + src/spikeinterface/extractors/neoextractors/mcsraw.py | 2 ++ src/spikeinterface/extractors/neoextractors/neuralynx.py | 1 + src/spikeinterface/extractors/neoextractors/nix.py | 2 ++ src/spikeinterface/extractors/neoextractors/plexon.py | 2 ++ src/spikeinterface/extractors/neoextractors/spike2.py | 2 ++ src/spikeinterface/extractors/neoextractors/spikegadgets.py | 2 ++ src/spikeinterface/extractors/neoextractors/spikeglx.py | 1 + src/spikeinterface/extractors/neoextractors/tdt.py | 2 ++ 15 files changed, 27 insertions(+) diff --git a/src/spikeinterface/extractors/neoextractors/alphaomega.py b/src/spikeinterface/extractors/neoextractors/alphaomega.py index e546b9d971..ebe7634f72 100644 --- a/src/spikeinterface/extractors/neoextractors/alphaomega.py +++ b/src/spikeinterface/extractors/neoextractors/alphaomega.py @@ -1,3 +1,5 @@ +from pathlib import Path + from spikeinterface.core.core_tools import define_function_from_class from .neobaseextractor import NeoBaseRecordingExtractor, NeoBaseEventExtractor diff --git a/src/spikeinterface/extractors/neoextractors/axona.py b/src/spikeinterface/extractors/neoextractors/axona.py index f8fb4a6733..00eb87003e 100644 --- a/src/spikeinterface/extractors/neoextractors/axona.py +++ 
b/src/spikeinterface/extractors/neoextractors/axona.py @@ -1,3 +1,5 @@ +from pathlib import Path + from spikeinterface.core.core_tools import define_function_from_class from .neobaseextractor import NeoBaseRecordingExtractor diff --git a/src/spikeinterface/extractors/neoextractors/biocam.py b/src/spikeinterface/extractors/neoextractors/biocam.py index 1886317ffc..5b9629c95b 100644 --- a/src/spikeinterface/extractors/neoextractors/biocam.py +++ b/src/spikeinterface/extractors/neoextractors/biocam.py @@ -1,3 +1,5 @@ +from pathlib import Path + import probeinterface as pi from spikeinterface.core.core_tools import define_function_from_class diff --git a/src/spikeinterface/extractors/neoextractors/ced.py b/src/spikeinterface/extractors/neoextractors/ced.py index c5d83d25e7..c6b5e35a0b 100644 --- a/src/spikeinterface/extractors/neoextractors/ced.py +++ b/src/spikeinterface/extractors/neoextractors/ced.py @@ -1,3 +1,5 @@ +from pathlib import Path + from spikeinterface.core.core_tools import define_function_from_class from .neobaseextractor import NeoBaseRecordingExtractor diff --git a/src/spikeinterface/extractors/neoextractors/edf.py b/src/spikeinterface/extractors/neoextractors/edf.py index a27423e523..6d89fa2bf9 100644 --- a/src/spikeinterface/extractors/neoextractors/edf.py +++ b/src/spikeinterface/extractors/neoextractors/edf.py @@ -1,3 +1,5 @@ +from pathlib import Path + from spikeinterface.core.core_tools import define_function_from_class from .neobaseextractor import NeoBaseRecordingExtractor diff --git a/src/spikeinterface/extractors/neoextractors/intan.py b/src/spikeinterface/extractors/neoextractors/intan.py index a456052668..7eacdecd0c 100644 --- a/src/spikeinterface/extractors/neoextractors/intan.py +++ b/src/spikeinterface/extractors/neoextractors/intan.py @@ -1,3 +1,5 @@ +from pathlib import Path + from spikeinterface.core.core_tools import define_function_from_class from .neobaseextractor import NeoBaseRecordingExtractor, NeoBaseSortingExtractor diff --git a/src/spikeinterface/extractors/neoextractors/maxwell.py b/src/spikeinterface/extractors/neoextractors/maxwell.py index 0a65e2e20a..329243ed95 100644 --- a/src/spikeinterface/extractors/neoextractors/maxwell.py +++ b/src/spikeinterface/extractors/neoextractors/maxwell.py @@ -1,4 +1,5 @@ import numpy as np +from pathlib import Path import probeinterface as pi diff --git a/src/spikeinterface/extractors/neoextractors/mcsraw.py b/src/spikeinterface/extractors/neoextractors/mcsraw.py index 5a8ca1598d..e79e9fadae 100644 --- a/src/spikeinterface/extractors/neoextractors/mcsraw.py +++ b/src/spikeinterface/extractors/neoextractors/mcsraw.py @@ -1,3 +1,5 @@ +from pathlib import Path + from spikeinterface.core.core_tools import define_function_from_class from .neobaseextractor import NeoBaseRecordingExtractor diff --git a/src/spikeinterface/extractors/neoextractors/neuralynx.py b/src/spikeinterface/extractors/neoextractors/neuralynx.py index b9e98692a6..328af53aac 100644 --- a/src/spikeinterface/extractors/neoextractors/neuralynx.py +++ b/src/spikeinterface/extractors/neoextractors/neuralynx.py @@ -1,4 +1,5 @@ from typing import Optional +from pathlib import Path from spikeinterface.core.core_tools import define_function_from_class diff --git a/src/spikeinterface/extractors/neoextractors/nix.py b/src/spikeinterface/extractors/neoextractors/nix.py index 9da7323421..d91084efe0 100644 --- a/src/spikeinterface/extractors/neoextractors/nix.py +++ b/src/spikeinterface/extractors/neoextractors/nix.py @@ -1,3 +1,5 @@ +from pathlib import Path + 
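
All of these extractor patches apply the same one-line pattern: normalize the user-supplied path before storing it in the serialized kwargs, so a dumped extractor can be reloaded from any working directory. A standalone sketch of the pattern (the file name is hypothetical):

    from pathlib import Path

    file_path = "data/../session1.bin"  # relative input with a ".." component
    kwargs_path = str(Path(file_path).resolve().absolute())
    # Path.resolve() expands ".." and symlinks against the current working
    # directory and already returns an absolute path, so the trailing
    # .absolute() is redundant but harmless
    print(kwargs_path)  # e.g. /home/user/session1.bin
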
from spikeinterface.core.core_tools import define_function_from_class from .neobaseextractor import NeoBaseRecordingExtractor diff --git a/src/spikeinterface/extractors/neoextractors/plexon.py b/src/spikeinterface/extractors/neoextractors/plexon.py index 9fee26ef9c..9b8adf25f5 100644 --- a/src/spikeinterface/extractors/neoextractors/plexon.py +++ b/src/spikeinterface/extractors/neoextractors/plexon.py @@ -1,3 +1,5 @@ +from pathlib import Path + from spikeinterface.core.core_tools import define_function_from_class from .neobaseextractor import NeoBaseRecordingExtractor, NeoBaseSortingExtractor diff --git a/src/spikeinterface/extractors/neoextractors/spike2.py b/src/spikeinterface/extractors/neoextractors/spike2.py index aaa6c198f8..0efd2e221c 100644 --- a/src/spikeinterface/extractors/neoextractors/spike2.py +++ b/src/spikeinterface/extractors/neoextractors/spike2.py @@ -1,3 +1,5 @@ +from pathlib import Path + from spikeinterface.core.core_tools import define_function_from_class from .neobaseextractor import NeoBaseRecordingExtractor, NeoBaseSortingExtractor diff --git a/src/spikeinterface/extractors/neoextractors/spikegadgets.py b/src/spikeinterface/extractors/neoextractors/spikegadgets.py index 02d0c92511..d6c85eb5de 100644 --- a/src/spikeinterface/extractors/neoextractors/spikegadgets.py +++ b/src/spikeinterface/extractors/neoextractors/spikegadgets.py @@ -1,3 +1,5 @@ +from pathlib import Path + from spikeinterface.core.core_tools import define_function_from_class from .neobaseextractor import NeoBaseRecordingExtractor diff --git a/src/spikeinterface/extractors/neoextractors/spikeglx.py b/src/spikeinterface/extractors/neoextractors/spikeglx.py index a9a0dd0000..8566a8b0cd 100644 --- a/src/spikeinterface/extractors/neoextractors/spikeglx.py +++ b/src/spikeinterface/extractors/neoextractors/spikeglx.py @@ -1,6 +1,7 @@ from packaging import version import numpy as np +from pathlib import Path import neo import probeinterface as pi diff --git a/src/spikeinterface/extractors/neoextractors/tdt.py b/src/spikeinterface/extractors/neoextractors/tdt.py index 41cf190e15..3a543a5131 100644 --- a/src/spikeinterface/extractors/neoextractors/tdt.py +++ b/src/spikeinterface/extractors/neoextractors/tdt.py @@ -1,3 +1,5 @@ +from pathlib import Path + from spikeinterface.core.core_tools import define_function_from_class from .neobaseextractor import NeoBaseRecordingExtractor From 7d3b7b7967adbfe0655b463bfe56d0bbc668ef3e Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Mon, 17 Jul 2023 17:16:22 +0200 Subject: [PATCH 064/166] fix bug --- .../extractors/neoextractors/mearec.py | 26 ++++++++++++++++--- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/extractors/neoextractors/mearec.py b/src/spikeinterface/extractors/neoextractors/mearec.py index ab24034b9a..3f024bcc49 100644 --- a/src/spikeinterface/extractors/neoextractors/mearec.py +++ b/src/spikeinterface/extractors/neoextractors/mearec.py @@ -8,6 +8,22 @@ from .neobaseextractor import NeoBaseRecordingExtractor, NeoBaseSortingExtractor +def drop_neo_arguments_in_version_0_11_0(neo_kwargs): + # Temporary function until neo version 0.12.0 is released + from packaging.version import parse as parse_version + from importlib.metadata import version + + neo_version = version("neo") + minor_version = parse_version(neo_version).minor + + # The possibility of loading only spike_trains or only analog_signals is not present in neo <= 0.11.0 + if minor_version < 12: + neo_kwargs.pop("load_spiketrains") + 
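
The guard being added here strips kwargs that neo only understands from 0.12 on. The same technique as a self-contained sketch, assuming neo and packaging are installed (the function name below is illustrative, not part of the patch):

    from importlib.metadata import version
    from packaging.version import parse as parse_version

    def adapt_kwargs_to_installed_neo(neo_kwargs):
        # drop arguments that neo < 0.12 does not accept, mirroring
        # drop_neo_arguments_in_version_0_11_0 above
        if parse_version(version("neo")).minor < 12:
            neo_kwargs = {k: v for k, v in neo_kwargs.items()
                          if k not in ("load_spiketrains", "load_analogsignal")}
        return neo_kwargs

    print(adapt_kwargs_to_installed_neo({"filename": "rec.h5", "load_spiketrains": True}))
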
neo_kwargs.pop("load_analogsignal") + + return neo_kwargs + + class MEArecRecordingExtractor(NeoBaseRecordingExtractor): """ Class for reading data from a MEArec simulated data. @@ -43,11 +59,10 @@ def __init__(self, file_path: Union[str, Path], all_annotations: bool = False): self._kwargs.update({"file_path": str(file_path)}) @classmethod - def map_to_neo_kwargs( - cls, - file_path, - ): + def map_to_neo_kwargs(cls, file_path): neo_kwargs = {"filename": str(file_path), "load_spiketrains": False, "load_analogsignal": True} + # The possibility of loading only spike_trains or only analog_signals will be added in neo version 0.12.0 + neo_kwargs = drop_neo_arguments_in_version_0_11_0(neo_kwargs=neo_kwargs) return neo_kwargs @@ -68,6 +83,9 @@ def __init__(self, file_path: Union[str, Path]): @classmethod def map_to_neo_kwargs(cls, file_path): neo_kwargs = {"filename": str(file_path), "load_spiketrains": True, "load_analogsignal": False} + # The possibility of loading only spike_trains or only analog_signals will be added in neo version 0.12.0 + neo_kwargs = drop_neo_arguments_in_version_0_11_0(neo_kwargs=neo_kwargs) + return neo_kwargs def read_sampling_frequency(self, file_path: Union[str, Path]) -> float: From 6cc493a4e20a0d2d3d76b065446e29e81333ff39 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20WYNGAARD?= Date: Mon, 17 Jul 2023 18:10:32 +0200 Subject: [PATCH 065/166] Fixed `numba.jit` and `binary num_chan` warnings --- .../comparison/tests/test_groundtruthstudy.py | 3 +++ src/spikeinterface/core/tests/test_core_tools.py | 8 ++++---- .../sortingcomponents/clustering/sliding_nn.py | 2 +- src/spikeinterface/sortingcomponents/peak_detection.py | 4 ++-- 4 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/comparison/tests/test_groundtruthstudy.py b/src/spikeinterface/comparison/tests/test_groundtruthstudy.py index 6af2698211..04d3171a01 100644 --- a/src/spikeinterface/comparison/tests/test_groundtruthstudy.py +++ b/src/spikeinterface/comparison/tests/test_groundtruthstudy.py @@ -1,3 +1,4 @@ +import importlib import shutil import pytest from pathlib import Path @@ -33,6 +34,7 @@ def _setup_comparison_study(): study = GroundTruthStudy.create(study_folder, gt_dict) +@pytest.mark.skipif(importlib.util.find_spec("tridesclous") is None, reason="Test requires Python package 'tridesclous'") def test_run_study_sorters(): study = GroundTruthStudy(study_folder) sorter_list = [ @@ -45,6 +47,7 @@ def test_run_study_sorters(): study.run_sorters(sorter_list) +@pytest.mark.skipif(importlib.util.find_spec("tridesclous") is None, reason="Test requires Python package 'tridesclous'") def test_extract_sortings(): study = GroundTruthStudy(study_folder) diff --git a/src/spikeinterface/core/tests/test_core_tools.py b/src/spikeinterface/core/tests/test_core_tools.py index 89a4143e19..3dc09f1e08 100644 --- a/src/spikeinterface/core/tests/test_core_tools.py +++ b/src/spikeinterface/core/tests/test_core_tools.py @@ -35,7 +35,7 @@ def test_write_binary_recording(tmp_path): # Check if written data matches original data recorder_binary = BinaryRecordingExtractor( - file_paths=file_paths, sampling_frequency=sampling_frequency, num_chan=num_channels, dtype=dtype + file_paths=file_paths, sampling_frequency=sampling_frequency, num_channels=num_channels, dtype=dtype ) assert np.allclose(recorder_binary.get_traces(), recording.get_traces()) @@ -62,7 +62,7 @@ def test_write_binary_recording_offset(tmp_path): recorder_binary = BinaryRecordingExtractor( file_paths=file_paths, 
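
The rename these tests exercise (num_chan -> num_channels) can be shown with a runnable sketch that writes a small float32 file first; the file name is hypothetical:

    import numpy as np
    from spikeinterface.core import BinaryRecordingExtractor

    traces = np.zeros((30000, 4), dtype="float32")
    traces.tofile("traces.raw")  # interleaved (samples x channels), C order

    recording = BinaryRecordingExtractor(
        file_paths=["traces.raw"],
        sampling_frequency=30000.0,
        num_channels=4,  # formerly num_chan
        dtype="float32",
    )
    assert recording.get_num_channels() == 4
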
sampling_frequency=sampling_frequency, - num_chan=num_channels, + num_channels=num_channels, dtype=dtype, file_offset=byte_offset, ) @@ -91,7 +91,7 @@ def test_write_binary_recording_parallel(tmp_path): # Check if written data matches original data recorder_binary = BinaryRecordingExtractor( - file_paths=file_paths, sampling_frequency=sampling_frequency, num_chan=num_channels, dtype=dtype + file_paths=file_paths, sampling_frequency=sampling_frequency, num_channels=num_channels, dtype=dtype ) for segment_index in range(recording.get_num_segments()): binary_traces = recorder_binary.get_traces(segment_index=segment_index) @@ -118,7 +118,7 @@ def test_write_binary_recording_multiple_segment(tmp_path): # Check if written data matches original data recorder_binary = BinaryRecordingExtractor( - file_paths=file_paths, sampling_frequency=sampling_frequency, num_chan=num_channels, dtype=dtype + file_paths=file_paths, sampling_frequency=sampling_frequency, num_channels=num_channels, dtype=dtype ) for segment_index in range(recording.get_num_segments()): diff --git a/src/spikeinterface/sortingcomponents/clustering/sliding_nn.py b/src/spikeinterface/sortingcomponents/clustering/sliding_nn.py index 24fea6429f..68b34a7041 100644 --- a/src/spikeinterface/sortingcomponents/clustering/sliding_nn.py +++ b/src/spikeinterface/sortingcomponents/clustering/sliding_nn.py @@ -367,7 +367,7 @@ def main_function(cls, recording, peaks, params): if HAVE_NUMBA: - @numba.jit(fastmath=True, cache=True) + @numba.jit(nopython=True, fastmath=True, cache=True) def sparse_euclidean(x, y, n_samples, n_dense): """Euclidean distance metric over sparse vectors, where first n_dense elements are indices, and n_samples is the length of the second dimension diff --git a/src/spikeinterface/sortingcomponents/peak_detection.py b/src/spikeinterface/sortingcomponents/peak_detection.py index c495a3bfa4..df3374b39d 100644 --- a/src/spikeinterface/sortingcomponents/peak_detection.py +++ b/src/spikeinterface/sortingcomponents/peak_detection.py @@ -640,7 +640,7 @@ def detect_peaks(cls, traces, peak_sign, abs_threholds, exclude_sweep_size, devi if HAVE_NUMBA: - @numba.jit(parallel=False) + @numba.jit(nopython=True, parallel=False) def _numba_detect_peak_pos( traces, traces_center, peak_mask, exclude_sweep_size, abs_threholds, peak_sign, neighbours_mask ): @@ -665,7 +665,7 @@ def _numba_detect_peak_pos( break return peak_mask - @numba.jit(parallel=False) + @numba.jit(nopython=True, parallel=False) def _numba_detect_peak_neg( traces, traces_center, peak_mask, exclude_sweep_size, abs_threholds, peak_sign, neighbours_mask ): From ce5ea3b881289a0b653eab4b79f705b79e8828cf Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 17 Jul 2023 16:11:20 +0000 Subject: [PATCH 066/166] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../comparison/tests/test_groundtruthstudy.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/comparison/tests/test_groundtruthstudy.py b/src/spikeinterface/comparison/tests/test_groundtruthstudy.py index 04d3171a01..9d495e64c5 100644 --- a/src/spikeinterface/comparison/tests/test_groundtruthstudy.py +++ b/src/spikeinterface/comparison/tests/test_groundtruthstudy.py @@ -34,7 +34,9 @@ def _setup_comparison_study(): study = GroundTruthStudy.create(study_folder, gt_dict) -@pytest.mark.skipif(importlib.util.find_spec("tridesclous") is None, reason="Test requires Python 
package 'tridesclous'") +@pytest.mark.skipif( + importlib.util.find_spec("tridesclous") is None, reason="Test requires Python package 'tridesclous'" +) def test_run_study_sorters(): study = GroundTruthStudy(study_folder) sorter_list = [ @@ -47,7 +49,9 @@ def test_run_study_sorters(): study.run_sorters(sorter_list) -@pytest.mark.skipif(importlib.util.find_spec("tridesclous") is None, reason="Test requires Python package 'tridesclous'") +@pytest.mark.skipif( + importlib.util.find_spec("tridesclous") is None, reason="Test requires Python package 'tridesclous'" +) def test_extract_sortings(): study = GroundTruthStudy(study_folder) From 936746b4cc03eeebd13fd4fcccb55bb42007fabf Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 17 Jul 2023 18:14:22 +0200 Subject: [PATCH 067/166] Initial work to refactor widget in a unique class. --- src/spikeinterface/widgets/base.py | 89 +++++-- src/spikeinterface/widgets/unit_locations.py | 245 ++++++++++++++++++- 2 files changed, 310 insertions(+), 24 deletions(-) diff --git a/src/spikeinterface/widgets/base.py b/src/spikeinterface/widgets/base.py index 9a914bf28d..7b62dc3507 100644 --- a/src/spikeinterface/widgets/base.py +++ b/src/spikeinterface/widgets/base.py @@ -19,15 +19,63 @@ def set_default_plotter_backend(backend): default_backend_ = backend + +backend_kwargs_desc = { + "matplotlib": { + "figure": "Matplotlib figure. When None, it is created. Default None", + "ax": "Single matplotlib axis. When None, it is created. Default None", + "axes": "Multiple matplotlib axes. When None, they is created. Default None", + "ncols": "Number of columns to create in subplots. Default 5", + "figsize": "Size of matplotlib figure. Default None", + "figtitle": "The figure title. Default None", + }, + 'sortingview': { + "generate_url": "If True, the figurl URL is generated and printed. Default True", + "display": "If True and in jupyter notebook/lab, the widget is displayed in the cell. Default True.", + "figlabel": "The figurl figure label. Default None", + "height": "The height of the sortingview View in jupyter. Default None", + }, + "ipywidgets" : { + "width_cm": "Width of the figure in cm (default 10)", + "height_cm": "Height of the figure in cm (default 6)", + "display": "If True, widgets are immediately displayed", + }, + +} + +default_backend_kwargs = { + "matplotlib": {"figure": None, "ax": None, "axes": None, "ncols": 5, "figsize": None, "figtitle": None}, + "sortingview": {"generate_url": True, "display": True, "figlabel": None, "height": None}, + "ipywidgets" : {"width_cm": 25, "height_cm": 10, "display": True}, +} + + + class BaseWidget: # this need to be reset in the subclass possible_backends = None - def __init__(self, plot_data=None, backend=None, **backend_kwargs): + def __init__(self, data_plot=None, backend=None, **backend_kwargs): # every widgets must prepare a dict "plot_data" in the init - self.plot_data = plot_data + self.data_plot = data_plot self.backend = backend - self.backend_kwargs = backend_kwargs + + + for k in backend_kwargs: + if k not in default_backend_kwargs[backend]: + raise Exception( + f"{k} is not a valid plot argument or backend keyword argument. 
" + f"Possible backend keyword arguments for {backend} are: {list(plotter_kwargs.keys())}" + ) + backend_kwargs_ = default_backend_kwargs[backend].copy() + backend_kwargs_.update(backend_kwargs) + + self.backend_kwargs = backend_kwargs_ + + + func = getattr(self, f'plot_{backend}') + func(self) + def check_backend(self, backend): if backend is None: @@ -36,15 +84,16 @@ def check_backend(self, backend): f"{backend} backend not available! Available backends are: " f"{list(self.possible_backends.keys())}" ) return backend + - def check_backend_kwargs(self, plotter, backend, **backend_kwargs): - plotter_kwargs = plotter.default_backend_kwargs - for k in backend_kwargs: - if k not in plotter_kwargs: - raise Exception( - f"{k} is not a valid plot argument or backend keyword argument. " - f"Possible backend keyword arguments for {backend} are: {list(plotter_kwargs.keys())}" - ) + # def check_backend_kwargs(self, plotter, backend, **backend_kwargs): + # plotter_kwargs = plotter.default_backend_kwargs + # for k in backend_kwargs: + # if k not in plotter_kwargs: + # raise Exception( + # f"{k} is not a valid plot argument or backend keyword argument. " + # f"Possible backend keyword arguments for {backend} are: {list(plotter_kwargs.keys())}" + # ) def do_plot(self, backend, **backend_kwargs): backend = self.check_backend(backend) @@ -74,17 +123,17 @@ def check_extensions(waveform_extractor, extensions): raise Exception(error_msg) -class BackendPlotter: - backend = "" +# class BackendPlotter: +# backend = "" - @classmethod - def register(cls, widget_cls): - widget_cls.register_backend(cls) +# @classmethod +# def register(cls, widget_cls): +# widget_cls.register_backend(cls) - def update_backend_kwargs(self, **backend_kwargs): - backend_kwargs_ = self.default_backend_kwargs.copy() - backend_kwargs_.update(backend_kwargs) - return backend_kwargs_ +# def update_backend_kwargs(self, **backend_kwargs): +# backend_kwargs_ = self.default_backend_kwargs.copy() +# backend_kwargs_.update(backend_kwargs) +# return backend_kwargs_ def copy_signature(source_fct): diff --git a/src/spikeinterface/widgets/unit_locations.py b/src/spikeinterface/widgets/unit_locations.py index 2c58fdfe45..be9d9cacc8 100644 --- a/src/spikeinterface/widgets/unit_locations.py +++ b/src/spikeinterface/widgets/unit_locations.py @@ -1,7 +1,9 @@ import numpy as np from typing import Union -from .base import BaseWidget +from probeinterface import ProbeGroup + +from .base import BaseWidget, to_attr from .utils import get_unit_colors from ..core.waveform_extractor import WaveformExtractor @@ -31,7 +33,7 @@ class UnitLocationsWidget(BaseWidget): If True, the axis is set to off, default False (matplotlib backend) """ - possible_backends = {} + # possible_backends = {} def __init__( self, @@ -62,7 +64,7 @@ def __init__( if unit_ids is None: unit_ids = sorting.unit_ids - plot_data = dict( + data_plot = dict( all_unit_ids=sorting.unit_ids, unit_locations=unit_locations, sorting=sorting, @@ -78,4 +80,239 @@ def __init__( hide_axis=hide_axis, ) - BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + BaseWidget.__init__(self, data_plot, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, **backend_kwargs): + import matplotlib.pyplot as plt + from .matplotlib_utils import make_mpl_figure + from probeinterface.plotting import plot_probe + + from matplotlib.patches import Ellipse + from matplotlib.lines import Line2D + + + + + dp = to_attr(self.data_plot) + # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) + + 
+ # self.make_mpl_figure(**backend_kwargs) + self.figure, self.axes, self.ax = make_mpl_figure(self.backend_kwargs) + + + unit_locations = dp.unit_locations + + probegroup = ProbeGroup.from_dict(dp.probegroup_dict) + probe_shape_kwargs = dict(facecolor="w", edgecolor="k", lw=0.5, alpha=1.0) + contacts_kargs = dict(alpha=1.0, edgecolor="k", lw=0.5) + + for probe in probegroup.probes: + text_on_contact = None + if dp.with_channel_ids: + text_on_contact = dp.channel_ids + + poly_contact, poly_contour = plot_probe( + probe, + ax=self.ax, + contacts_colors="w", + contacts_kargs=contacts_kargs, + probe_shape_kwargs=probe_shape_kwargs, + text_on_contact=text_on_contact, + ) + poly_contact.set_zorder(2) + if poly_contour is not None: + poly_contour.set_zorder(1) + + self.ax.set_title("") + + # color = np.array([dp.unit_colors[unit_id] for unit_id in dp.unit_ids]) + width = height = 10 + ellipse_kwargs = dict(width=width, height=height, lw=2) + + if dp.plot_all_units: + unit_colors = {} + unit_ids = dp.all_unit_ids + for unit in dp.all_unit_ids: + if unit not in dp.unit_ids: + unit_colors[unit] = "gray" + else: + unit_colors[unit] = dp.unit_colors[unit] + else: + unit_ids = dp.unit_ids + unit_colors = dp.unit_colors + labels = dp.unit_ids + + patches = [ + Ellipse( + (unit_locations[unit]), + color=unit_colors[unit], + zorder=5 if unit in dp.unit_ids else 3, + alpha=0.9 if unit in dp.unit_ids else 0.5, + **ellipse_kwargs, + ) + for i, unit in enumerate(unit_ids) + ] + for p in patches: + self.ax.add_patch(p) + handles = [ + Line2D([0], [0], ls="", marker="o", markersize=5, markeredgewidth=2, color=unit_colors[unit]) + for unit in dp.unit_ids + ] + + if dp.plot_legend: + if hasattr(self, 'legend') and self.legend is not None: + # if self.legend is not None: + self.legend.remove() + self.legend = self.figure.legend( + handles, labels, loc="upper center", bbox_to_anchor=(0.5, 1.0), ncol=5, fancybox=True, shadow=True + ) + + if dp.hide_axis: + self.ax.axis("off") + + def plot_sortingview(self): + import sortingview.views as vv + from .sortingview_utils import generate_unit_table_view, make_serializable, handle_display_and_url + + # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) + backend_kwargs = self.backend_kwargs + dp = to_attr(self.data_plot) + + # ensure serializable for sortingview + unit_ids, channel_ids = make_serializable(dp.unit_ids, dp.channel_ids) + + locations = {str(ch): dp.channel_locations[i_ch].astype("float32") for i_ch, ch in enumerate(channel_ids)} + + unit_items = [] + for unit_id in unit_ids: + unit_items.append( + vv.UnitLocationsItem( + unit_id=unit_id, x=float(dp.unit_locations[unit_id][0]), y=float(dp.unit_locations[unit_id][1]) + ) + ) + + v_unit_locations = vv.UnitLocations(units=unit_items, channel_locations=locations, disable_auto_rotate=True) + + if not dp.hide_unit_selector: + v_units_table = generate_unit_table_view(dp.sorting) + + self.view = vv.Box( + direction="horizontal", + items=[vv.LayoutItem(v_units_table, max_size=150), vv.LayoutItem(v_unit_locations)], + ) + else: + self.view = v_unit_locations + + # self.handle_display_and_url(view, **backend_kwargs) + self.url = handle_display_and_url(self, self.view, **self.backend_kwargs) + + + def plot_ipywidgets(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + import ipywidgets.widgets as widgets + from IPython.display import display + + from .ipywidgets_utils import check_ipywidget_backend, make_unit_controller + + cm = 1 / 2.54 + + # backend_kwargs = 
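
The control flow PATCH 067 is converging on: a single widget class whose __init__ validates backend kwargs against a per-backend defaults table and then dispatches to the matching plot_<backend> method. A stripped-down sketch of that flow, with an illustrative defaults table (the patch's own func(self) call differs slightly; the sketch uses the plain bound call):

    default_backend_kwargs = {
        "matplotlib": {"figure": None, "ax": None},
        "ipywidgets": {"display": True},
    }

    class WidgetSketch:
        def __init__(self, data_plot, backend="matplotlib", **backend_kwargs):
            for k in backend_kwargs:
                if k not in default_backend_kwargs[backend]:
                    raise Exception(f"{k} is not a valid {backend} keyword argument")
            # user kwargs override the backend defaults
            self.backend_kwargs = {**default_backend_kwargs[backend], **backend_kwargs}
            self.data_plot = data_plot
            getattr(self, f"plot_{backend}")()  # dispatch on backend name

        def plot_matplotlib(self):
            print("matplotlib backend:", self.backend_kwargs)

    WidgetSketch({"unit_ids": [0, 1]}, ax=None)
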
self.update_backend_kwargs(**backend_kwargs) + + + width_cm = backend_kwargs["width_cm"] + height_cm = backend_kwargs["height_cm"] + + ratios = [0.15, 0.85] + + with plt.ioff(): + output = widgets.Output() + with output: + fig, ax = plt.subplots(figsize=((ratios[1] * width_cm) * cm, height_cm * cm)) + plt.show() + + data_plot["unit_ids"] = data_plot["unit_ids"][:1] + unit_widget, unit_controller = make_unit_controller( + data_plot["unit_ids"], list(data_plot["unit_colors"].keys()), ratios[0] * width_cm, height_cm + ) + + self.controller = unit_controller + + # mpl_plotter = MplUnitLocationsPlotter() + + # self.updater = PlotUpdater(data_plot, mpl_plotter, ax, self.controller) + # for w in self.controller.values(): + # w.observe(self.updater) + + for w in self.controller.values(): + w.observe(self.update_widget) + + self.widget = widgets.AppLayout( + center=fig.canvas, + left_sidebar=unit_widget, + pane_widths=ratios + [0], + ) + + # a first update + self.updater(None) + + if backend_kwargs["display"]: + self.check_backend() + display(self.widget) + + def update_widget(self, change): + self.ax.clear() + + unit_ids = self.controller["unit_ids"].value + + # matplotlib next_data_plot dict update at each call + # data_plot = self.next_data_plot + self.data_plot["unit_ids"] = unit_ids + self.data_plot["plot_all_units"] = True + self.data_plot["plot_legend"] = True + self.data_plot["hide_axis"] = True + + backend_kwargs = {} + backend_kwargs["ax"] = self.ax + + # self.mpl_plotter.do_plot(data_plot, **backend_kwargs) + self.plot_matplotlib(data_plot, **backend_kwargs) + fig = self.ax.get_figure() + fig.canvas.draw() + fig.canvas.flush_events() + + + + + +class PlotUpdater: + def __init__(self, data_plot, mpl_plotter, ax, controller): + self.data_plot = data_plot + self.mpl_plotter = mpl_plotter + self.ax = ax + self.controller = controller + + self.next_data_plot = data_plot.copy() + + def __call__(self, change): + self.ax.clear() + + unit_ids = self.controller["unit_ids"].value + + # matplotlib next_data_plot dict update at each call + data_plot = self.next_data_plot + data_plot["unit_ids"] = unit_ids + data_plot["plot_all_units"] = True + data_plot["plot_legend"] = True + data_plot["hide_axis"] = True + + backend_kwargs = {} + backend_kwargs["ax"] = self.ax + + # self.mpl_plotter.do_plot(data_plot, **backend_kwargs) + UnitLocationsWidget.plot_matplotlib(data_plot, **backend_kwargs) + fig = self.ax.get_figure() + fig.canvas.draw() + fig.canvas.flush_events() + + From 6bbd5805747c0cb4a37a7ba979b3c3f3d0f069f3 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 17 Jul 2023 19:37:47 +0200 Subject: [PATCH 068/166] Keep resolve() at base only --- src/spikeinterface/core/base.py | 4 ++-- src/spikeinterface/core/binaryfolder.py | 2 +- src/spikeinterface/core/binaryrecordingextractor.py | 2 +- src/spikeinterface/core/npyfoldersnippets.py | 2 +- src/spikeinterface/core/npysnippetsextractor.py | 2 +- src/spikeinterface/core/npzfolder.py | 2 +- src/spikeinterface/core/zarrrecordingextractor.py | 2 +- src/spikeinterface/exporters/to_phy.py | 2 +- .../extractors/cellexplorersortingextractor.py | 2 +- src/spikeinterface/extractors/combinatoextractors.py | 4 ++-- src/spikeinterface/extractors/hdsortextractors.py | 2 +- .../extractors/herdingspikesextractors.py | 2 +- src/spikeinterface/extractors/matlabhelpers.py | 2 +- src/spikeinterface/extractors/mdaextractors.py | 4 ++-- .../extractors/neoextractors/alphaomega.py | 6 +++--- src/spikeinterface/extractors/neoextractors/axona.py | 2 +- 
.../extractors/neoextractors/biocam.py | 4 ++-- .../extractors/neoextractors/blackrock.py | 6 +++--- src/spikeinterface/extractors/neoextractors/ced.py | 4 ++-- src/spikeinterface/extractors/neoextractors/edf.py | 6 +++--- src/spikeinterface/extractors/neoextractors/intan.py | 4 ++-- .../extractors/neoextractors/maxwell.py | 4 ++-- .../extractors/neoextractors/mcsraw.py | 4 ++-- .../extractors/neoextractors/mearec.py | 8 ++++---- .../extractors/neoextractors/neuralynx.py | 6 +++--- .../extractors/neoextractors/neuroscope.py | 4 +++- src/spikeinterface/extractors/neoextractors/nix.py | 4 ++-- .../extractors/neoextractors/openephys.py | 10 +++++----- .../extractors/neoextractors/plexon.py | 8 ++++---- .../extractors/neoextractors/spike2.py | 4 ++-- .../extractors/neoextractors/spikegadgets.py | 4 ++-- .../extractors/neoextractors/spikeglx.py | 6 ++---- src/spikeinterface/extractors/neoextractors/tdt.py | 4 ++-- src/spikeinterface/extractors/nwbextractors.py | 12 ++++++------ src/spikeinterface/extractors/shybridextractors.py | 4 ++-- .../extractors/spykingcircusextractors.py | 2 +- src/spikeinterface/extractors/tests/common_tests.py | 1 - src/spikeinterface/extractors/yassextractors.py | 2 +- src/spikeinterface/sorters/basesorter.py | 2 +- 39 files changed, 77 insertions(+), 78 deletions(-) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 897b25a71f..46d0214d27 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -644,10 +644,10 @@ def load(file_path: Union[str, Path], base_folder: Optional[Union[Path, str, boo if file_path.is_file(): # standard case based on a file (json or pickle) if str(file_path).endswith(".json"): - with open(str(Path(file_path).resolve().absolute()), "r") as f: + with open(str(Path(file_path)), "r") as f: d = json.load(f) elif str(file_path).endswith(".pkl") or str(file_path).endswith(".pickle"): - with open(str(Path(file_path).resolve().absolute()), "rb") as f: + with open(str(Path(file_path)), "rb") as f: d = pickle.load(f) else: raise ValueError(f"Impossible to load {file_path}") diff --git a/src/spikeinterface/core/binaryfolder.py b/src/spikeinterface/core/binaryfolder.py index 1a95d4b2bb..d185111b8c 100644 --- a/src/spikeinterface/core/binaryfolder.py +++ b/src/spikeinterface/core/binaryfolder.py @@ -47,7 +47,7 @@ def __init__(self, folder_path): folder_metadata = folder_path self.load_metadata_from_folder(folder_metadata) - self._kwargs = dict(folder_path=str(Path(folder_path).resolve().absolute())) + self._kwargs = dict(folder_path=str(Path(folder_path).absolute())) self._bin_kwargs = d["kwargs"] if "num_channels" not in self._bin_kwargs: assert "num_chan" in self._bin_kwargs, "Cannot find num_channels or num_chan in binary.json" diff --git a/src/spikeinterface/core/binaryrecordingextractor.py b/src/spikeinterface/core/binaryrecordingextractor.py index c41ea1e095..72a95637f6 100644 --- a/src/spikeinterface/core/binaryrecordingextractor.py +++ b/src/spikeinterface/core/binaryrecordingextractor.py @@ -116,7 +116,7 @@ def __init__( self.set_channel_offsets(offset_to_uV) self._kwargs = { - "file_paths": [str(Path(e).resolve().absolute()) for e in file_path_list], + "file_paths": [str(Path(e).absolute()) for e in file_path_list], "sampling_frequency": sampling_frequency, "t_starts": t_starts, "num_channels": num_channels, diff --git a/src/spikeinterface/core/npyfoldersnippets.py b/src/spikeinterface/core/npyfoldersnippets.py index 04d01954fb..c002bbe044 100644 --- 
a/src/spikeinterface/core/npyfoldersnippets.py +++ b/src/spikeinterface/core/npyfoldersnippets.py @@ -48,7 +48,7 @@ def __init__(self, folder_path): folder_metadata = folder_path self.load_metadata_from_folder(folder_metadata) - self._kwargs = dict(folder_path=str(Path(folder_path).resolve().absolute())) + self._kwargs = dict(folder_path=str(Path(folder_path).absolute())) self._bin_kwargs = d["kwargs"] diff --git a/src/spikeinterface/core/npysnippetsextractor.py b/src/spikeinterface/core/npysnippetsextractor.py index 12592dfee8..80979ce6c9 100644 --- a/src/spikeinterface/core/npysnippetsextractor.py +++ b/src/spikeinterface/core/npysnippetsextractor.py @@ -47,7 +47,7 @@ def __init__( self.set_channel_offsets(offset_to_uV) self._kwargs = { - "file_paths": [str(Path(f).resolve().absolute()) for f in file_paths], + "file_paths": [str(Path(f).absolute()) for f in file_paths], "sampling_frequency": sampling_frequency, "channel_ids": channel_ids, "nbefore": nbefore, diff --git a/src/spikeinterface/core/npzfolder.py b/src/spikeinterface/core/npzfolder.py index 0d79177dd2..5cdc46c353 100644 --- a/src/spikeinterface/core/npzfolder.py +++ b/src/spikeinterface/core/npzfolder.py @@ -47,7 +47,7 @@ def __init__(self, folder_path): folder_metadata = folder_path self.load_metadata_from_folder(folder_metadata) - self._kwargs = dict(folder_path=str(Path(folder_path).resolve().absolute())) + self._kwargs = dict(folder_path=str(Path(folder_path).absolute())) self._npz_kwargs = d["kwargs"] diff --git a/src/spikeinterface/core/zarrrecordingextractor.py b/src/spikeinterface/core/zarrrecordingextractor.py index 44e8674a73..4dc94a24dd 100644 --- a/src/spikeinterface/core/zarrrecordingextractor.py +++ b/src/spikeinterface/core/zarrrecordingextractor.py @@ -49,7 +49,7 @@ def __init__(self, root_path: Union[Path, str], storage_options=None): root_path = Path(root_path) else: root_path_init = str(root_path) - root_path_kwarg = str(Path(root_path).resolve().absolute()) + root_path_kwarg = str(Path(root_path).absolute()) else: root_path_init = root_path root_path_kwarg = root_path_init diff --git a/src/spikeinterface/exporters/to_phy.py b/src/spikeinterface/exporters/to_phy.py index df1a00471f..58804037b0 100644 --- a/src/spikeinterface/exporters/to_phy.py +++ b/src/spikeinterface/exporters/to_phy.py @@ -111,7 +111,7 @@ def export_to_phy( if len(unit_ids) == 0: raise Exception("No non-empty units in the sorting result, can't save to Phy.") - output_folder = Path(output_folder).resolve().absolute() + output_folder = Path(output_folder).absolute() if output_folder.is_dir(): if remove_if_exists: shutil.rmtree(output_folder) diff --git a/src/spikeinterface/extractors/cellexplorersortingextractor.py b/src/spikeinterface/extractors/cellexplorersortingextractor.py index b9a4c6c576..b40b998103 100644 --- a/src/spikeinterface/extractors/cellexplorersortingextractor.py +++ b/src/spikeinterface/extractors/cellexplorersortingextractor.py @@ -203,7 +203,7 @@ def _retrieve_sampling_frequency_from_session_info_file(self) -> float: if self.session_info_file_path is None: self.session_info_file_path = self.session_path / f"{self.session_id}.sessionInfo.mat" - self.session_info_file_path = Path(self.session_info_file_path).resolve().absolute() + self.session_info_file_path = Path(self.session_info_file_path).absolute() assert ( self.session_info_file_path.is_file() ), f"No {self.session_id}.sessionInfo.mat file found in the {self.session_path}!, can't inferr sampling rate, please pass the sampling rate at initialization" diff --git 
a/src/spikeinterface/extractors/combinatoextractors.py b/src/spikeinterface/extractors/combinatoextractors.py index 7a682bc4f1..fa2bdde450 100644 --- a/src/spikeinterface/extractors/combinatoextractors.py +++ b/src/spikeinterface/extractors/combinatoextractors.py @@ -44,7 +44,7 @@ def __init__(self, folder_path, sampling_frequency=None, user="simple", det_sign folder_path = Path(folder_path) assert folder_path.is_dir(), "Folder {} doesn't exist".format(folder_path) if sampling_frequency is None: - h5_path = str(Path(folder_path).resolve().absolute()) + ".h5" + h5_path = str(Path(folder_path).absolute()) + ".h5" if Path(h5_path).exists(): with h5py.File(h5_path, mode="r") as f: sampling_frequency = f["sr"][0] @@ -85,7 +85,7 @@ def __init__(self, folder_path, sampling_frequency=None, user="simple", det_sign self.add_sorting_segment(CombinatoSortingSegment(spiketrains)) self.set_property("unsorted", np.array([metadata[u]["group_type"] == 0 for u in range(unit_counter)])) self.set_property("artifact", np.array([metadata[u]["group_type"] == -1 for u in range(unit_counter)])) - self._kwargs = {"folder_path": str(Path(folder_path).resolve().absolute()), "user": user, "det_sign": det_sign} + self._kwargs = {"folder_path": str(Path(folder_path).absolute()), "user": user, "det_sign": det_sign} self.extra_requirements.append("h5py") diff --git a/src/spikeinterface/extractors/hdsortextractors.py b/src/spikeinterface/extractors/hdsortextractors.py index 3906fb8457..178596d052 100644 --- a/src/spikeinterface/extractors/hdsortextractors.py +++ b/src/spikeinterface/extractors/hdsortextractors.py @@ -108,7 +108,7 @@ def __init__(self, file_path, keep_good_only=True): self.set_property("template", np.array(templates)) self.set_property("template_frames_cut_before", np.array(templates_frames_cut_before)) - self._kwargs = {"file_path": str(Path(file_path).resolve().absolute()), "keep_good_only": keep_good_only} + self._kwargs = {"file_path": str(Path(file_path).absolute()), "keep_good_only": keep_good_only} # TODO features # ~ for uc, unit in enumerate(units): diff --git a/src/spikeinterface/extractors/herdingspikesextractors.py b/src/spikeinterface/extractors/herdingspikesextractors.py index 695eba9750..1fc71b1cd0 100644 --- a/src/spikeinterface/extractors/herdingspikesextractors.py +++ b/src/spikeinterface/extractors/herdingspikesextractors.py @@ -57,7 +57,7 @@ def __init__(self, file_path, load_unit_info=True): BaseSorting.__init__(self, sampling_frequency, unit_ids) self.add_sorting_segment(HerdingspikesSortingSegment(unit_ids, spike_times, spike_ids)) - self._kwargs = {"file_path": str(Path(file_path).resolve().absolute()), "load_unit_info": load_unit_info} + self._kwargs = {"file_path": str(Path(file_path).absolute()), "load_unit_info": load_unit_info} self.extra_requirements.append("h5py") diff --git a/src/spikeinterface/extractors/matlabhelpers.py b/src/spikeinterface/extractors/matlabhelpers.py index 7b61ed17cf..4f22d25339 100644 --- a/src/spikeinterface/extractors/matlabhelpers.py +++ b/src/spikeinterface/extractors/matlabhelpers.py @@ -26,7 +26,7 @@ def __init__(self, file_path): if not file_path.is_file(): raise ValueError(f"Specified file path '{file_path}' is not a file.") - self._kwargs = {"file_path": str(Path(file_path).resolve().absolute())} + self._kwargs = {"file_path": str(Path(file_path).absolute())} try: # load old-style (up to 7.2) .mat file self._data = loadmat(file_path, matlab_compatible=True) diff --git a/src/spikeinterface/extractors/mdaextractors.py 
b/src/spikeinterface/extractors/mdaextractors.py index 3adf9fcb62..815c617677 100644 --- a/src/spikeinterface/extractors/mdaextractors.py +++ b/src/spikeinterface/extractors/mdaextractors.py @@ -196,7 +196,7 @@ class MdaSortingExtractor(BaseSorting): name = "mda" def __init__(self, file_path, sampling_frequency): - firings = readmda(str(Path(file_path).resolve().absolute())) + firings = readmda(str(Path(file_path).absolute())) labels = firings[2, :] unit_ids = np.unique(labels).astype(int) BaseSorting.__init__(self, unit_ids=unit_ids, sampling_frequency=sampling_frequency) @@ -205,7 +205,7 @@ def __init__(self, file_path, sampling_frequency): self.add_sorting_segment(sorting_segment) self._kwargs = { - "file_path": str(Path(file_path).resolve().absolute()), + "file_path": str(Path(file_path).absolute()), "sampling_frequency": sampling_frequency, } diff --git a/src/spikeinterface/extractors/neoextractors/alphaomega.py b/src/spikeinterface/extractors/neoextractors/alphaomega.py index ebe7634f72..d806a07fe9 100644 --- a/src/spikeinterface/extractors/neoextractors/alphaomega.py +++ b/src/spikeinterface/extractors/neoextractors/alphaomega.py @@ -34,12 +34,12 @@ def __init__(self, folder_path, lsx_files=None, stream_id="RAW", stream_name=Non NeoBaseRecordingExtractor.__init__( self, stream_id=stream_id, stream_name=stream_name, all_annotations=all_annotations, **neo_kwargs ) - self._kwargs.update(dict(folder_path=str(Path(folder_path).resolve().absolute()), lsx_files=lsx_files)) + self._kwargs.update(dict(folder_path=str(Path(folder_path).absolute()), lsx_files=lsx_files)) @classmethod def map_to_neo_kwargs(cls, folder_path, lsx_files=None): neo_kwargs = { - "dirname": str(Path(folder_path).resolve().absolute()), + "dirname": str(Path(folder_path).absolute()), "lsx_files": lsx_files, } return neo_kwargs @@ -60,7 +60,7 @@ def __init__(self, folder_path): @classmethod def map_to_neo_kwargs(cls, folder_path): - neo_kwargs = {"dirname": str(Path(folder_path).resolve().absolute())} + neo_kwargs = {"dirname": str(Path(folder_path).absolute())} return neo_kwargs diff --git a/src/spikeinterface/extractors/neoextractors/axona.py b/src/spikeinterface/extractors/neoextractors/axona.py index 00eb87003e..0772ea0a26 100644 --- a/src/spikeinterface/extractors/neoextractors/axona.py +++ b/src/spikeinterface/extractors/neoextractors/axona.py @@ -30,7 +30,7 @@ def __init__(self, file_path, all_annotations=False): @classmethod def map_to_neo_kwargs(cls, file_path): - neo_kwargs = {"filename": str(Path(file_path).resolve().absolute())} + neo_kwargs = {"filename": str(Path(file_path).absolute())} return neo_kwargs diff --git a/src/spikeinterface/extractors/neoextractors/biocam.py b/src/spikeinterface/extractors/neoextractors/biocam.py index 5b9629c95b..f3483b1cd5 100644 --- a/src/spikeinterface/extractors/neoextractors/biocam.py +++ b/src/spikeinterface/extractors/neoextractors/biocam.py @@ -61,7 +61,7 @@ def __init__( self._kwargs.update( { - "file_path": str(Path(file_path).resolve().absolute()), + "file_path": str(Path(file_path).absolute()), "mea_pitch": mea_pitch, "electrode_width": electrode_width, } @@ -69,7 +69,7 @@ def __init__( @classmethod def map_to_neo_kwargs(cls, file_path): - neo_kwargs = {"filename": str(Path(file_path).resolve().absolute())} + neo_kwargs = {"filename": str(Path(file_path).absolute())} return neo_kwargs diff --git a/src/spikeinterface/extractors/neoextractors/blackrock.py b/src/spikeinterface/extractors/neoextractors/blackrock.py index 98ddf4204b..f36995e37e 100644 --- 
a/src/spikeinterface/extractors/neoextractors/blackrock.py +++ b/src/spikeinterface/extractors/neoextractors/blackrock.py @@ -56,11 +56,11 @@ def __init__( use_names_as_ids=use_names_as_ids, **neo_kwargs, ) - self._kwargs.update({"file_path": str(Path(file_path).resolve().absolute())}) + self._kwargs.update({"file_path": str(Path(file_path).absolute())}) @classmethod def map_to_neo_kwargs(cls, file_path): - neo_kwargs = {"filename": str(Path(file_path).resolve().absolute())} + neo_kwargs = {"filename": str(Path(file_path).absolute())} return neo_kwargs @@ -115,7 +115,7 @@ def __init__( @classmethod def map_to_neo_kwargs(cls, file_path): - neo_kwargs = {"filename": str(Path(file_path).resolve().absolute())} + neo_kwargs = {"filename": str(Path(file_path).absolute())} return neo_kwargs diff --git a/src/spikeinterface/extractors/neoextractors/ced.py b/src/spikeinterface/extractors/neoextractors/ced.py index c6b5e35a0b..49939c4897 100644 --- a/src/spikeinterface/extractors/neoextractors/ced.py +++ b/src/spikeinterface/extractors/neoextractors/ced.py @@ -36,12 +36,12 @@ def __init__(self, file_path, stream_id=None, stream_name=None, all_annotations= NeoBaseRecordingExtractor.__init__( self, stream_id=stream_id, stream_name=stream_name, all_annotations=all_annotations, **neo_kwargs ) - self._kwargs.update(dict(file_path=str(Path(file_path).resolve().absolute()))) + self._kwargs.update(dict(file_path=str(Path(file_path).absolute()))) self.extra_requirements.append("neo[ced]") @classmethod def map_to_neo_kwargs(cls, file_path): - neo_kwargs = {"filename": str(Path(file_path).resolve().absolute())} + neo_kwargs = {"filename": str(Path(file_path).absolute())} return neo_kwargs diff --git a/src/spikeinterface/extractors/neoextractors/edf.py b/src/spikeinterface/extractors/neoextractors/edf.py index 6d89fa2bf9..309a686b69 100644 --- a/src/spikeinterface/extractors/neoextractors/edf.py +++ b/src/spikeinterface/extractors/neoextractors/edf.py @@ -29,16 +29,16 @@ class EDFRecordingExtractor(NeoBaseRecordingExtractor): name = "edf" def __init__(self, file_path, stream_id=None, stream_name=None, all_annotations=False): - neo_kwargs = {"filename": str(Path(file_path).resolve().absolute())} + neo_kwargs = {"filename": str(Path(file_path).absolute())} NeoBaseRecordingExtractor.__init__( self, stream_id=stream_id, stream_name=stream_name, all_annotations=all_annotations, **neo_kwargs ) - self._kwargs.update({"file_path": str(Path(file_path).resolve().absolute())}) + self._kwargs.update({"file_path": str(Path(file_path).absolute())}) self.extra_requirements.append("neo[edf]") @classmethod def map_to_neo_kwargs(cls, file_path): - neo_kwargs = {"filename": str(Path(file_path).resolve().absolute())} + neo_kwargs = {"filename": str(Path(file_path).absolute())} return neo_kwargs diff --git a/src/spikeinterface/extractors/neoextractors/intan.py b/src/spikeinterface/extractors/neoextractors/intan.py index 7eacdecd0c..330a1e4682 100644 --- a/src/spikeinterface/extractors/neoextractors/intan.py +++ b/src/spikeinterface/extractors/neoextractors/intan.py @@ -38,11 +38,11 @@ def __init__(self, file_path, stream_id=None, stream_name=None, all_annotations= **neo_kwargs, ) - self._kwargs.update(dict(file_path=str(Path(file_path).resolve().absolute()))) + self._kwargs.update(dict(file_path=str(Path(file_path).absolute()))) @classmethod def map_to_neo_kwargs(cls, file_path): - neo_kwargs = {"filename": str(Path(file_path).resolve().absolute())} + neo_kwargs = {"filename": str(Path(file_path).absolute())} return neo_kwargs diff 
--git a/src/spikeinterface/extractors/neoextractors/maxwell.py b/src/spikeinterface/extractors/neoextractors/maxwell.py index 329243ed95..602a054ed8 100644 --- a/src/spikeinterface/extractors/neoextractors/maxwell.py +++ b/src/spikeinterface/extractors/neoextractors/maxwell.py @@ -71,11 +71,11 @@ def __init__( probe = pi.read_maxwell(file_path, well_name=well_name, rec_name=rec_name) self.set_probe(probe, in_place=True) self.set_property("electrode", self.get_property("contact_vector")["electrode"]) - self._kwargs.update(dict(file_path=str(Path(file_path).resolve().absolute()), rec_name=rec_name)) + self._kwargs.update(dict(file_path=str(Path(file_path).absolute()), rec_name=rec_name)) @classmethod def map_to_neo_kwargs(cls, file_path, rec_name=None): - neo_kwargs = {"filename": str(Path(file_path).resolve().absolute()), "rec_name": rec_name} + neo_kwargs = {"filename": str(Path(file_path).absolute()), "rec_name": rec_name} return neo_kwargs def install_maxwell_plugin(self, force_download=False): diff --git a/src/spikeinterface/extractors/neoextractors/mcsraw.py b/src/spikeinterface/extractors/neoextractors/mcsraw.py index e79e9fadae..84d0ff1b7d 100644 --- a/src/spikeinterface/extractors/neoextractors/mcsraw.py +++ b/src/spikeinterface/extractors/neoextractors/mcsraw.py @@ -42,11 +42,11 @@ def __init__(self, file_path, stream_id=None, stream_name=None, block_index=None all_annotations=all_annotations, **neo_kwargs, ) - self._kwargs.update(dict(file_path=str(Path(file_path).resolve().absolute()))) + self._kwargs.update(dict(file_path=str(Path(file_path).absolute()))) @classmethod def map_to_neo_kwargs(cls, file_path): - neo_kwargs = {"filename": str(Path(file_path).resolve().absolute())} + neo_kwargs = {"filename": str(Path(file_path).absolute())} return neo_kwargs diff --git a/src/spikeinterface/extractors/neoextractors/mearec.py b/src/spikeinterface/extractors/neoextractors/mearec.py index 82e2f2b858..9826464ecf 100644 --- a/src/spikeinterface/extractors/neoextractors/mearec.py +++ b/src/spikeinterface/extractors/neoextractors/mearec.py @@ -40,7 +40,7 @@ def __init__(self, file_path: Union[str, Path], all_annotations: bool = False): if hasattr(self.neo_reader._recgen, "gain_to_uV"): self.set_channel_gains(self.neo_reader._recgen.gain_to_uV) - self._kwargs.update({"file_path": str(Path(file_path).resolve().absolute())}) + self._kwargs.update({"file_path": str(Path(file_path).absolute())}) @classmethod def map_to_neo_kwargs( @@ -48,7 +48,7 @@ def map_to_neo_kwargs( file_path, ): neo_kwargs = { - "filename": str(Path(file_path).resolve().absolute()), + "filename": str(Path(file_path).absolute()), "load_spiketrains": False, "load_analogsignal": True, } @@ -67,12 +67,12 @@ def __init__(self, file_path: Union[str, Path]): sampling_frequency = self.read_sampling_frequency(file_path=file_path) NeoBaseSortingExtractor.__init__(self, sampling_frequency=sampling_frequency, use_format_ids=True, **neo_kwargs) - self._kwargs = {"file_path": str(Path(file_path).resolve().absolute())} + self._kwargs = {"file_path": str(Path(file_path).absolute())} @classmethod def map_to_neo_kwargs(cls, file_path): neo_kwargs = { - "filename": str(Path(file_path).resolve().absolute()), + "filename": str(Path(file_path).absolute()), "load_spiketrains": True, "load_analogsignal": False, } diff --git a/src/spikeinterface/extractors/neoextractors/neuralynx.py b/src/spikeinterface/extractors/neoextractors/neuralynx.py index 328af53aac..351fbd6e44 100644 --- a/src/spikeinterface/extractors/neoextractors/neuralynx.py +++ 
b/src/spikeinterface/extractors/neoextractors/neuralynx.py @@ -33,11 +33,11 @@ def __init__(self, folder_path, stream_id=None, stream_name=None, all_annotation NeoBaseRecordingExtractor.__init__( self, stream_id=stream_id, stream_name=stream_name, all_annotations=all_annotations, **neo_kwargs ) - self._kwargs.update(dict(folder_path=str(Path(folder_path).resolve().absolute()))) + self._kwargs.update(dict(folder_path=str(Path(folder_path).absolute()))) @classmethod def map_to_neo_kwargs(cls, folder_path): - neo_kwargs = {"dirname": str(Path(folder_path).resolve().absolute())} + neo_kwargs = {"dirname": str(Path(folder_path).absolute())} return neo_kwargs @@ -91,7 +91,7 @@ def __init__( @classmethod def map_to_neo_kwargs(cls, folder_path): - neo_kwargs = {"dirname": str(Path(folder_path).resolve().absolute())} + neo_kwargs = {"dirname": str(Path(folder_path).absolute())} return neo_kwargs diff --git a/src/spikeinterface/extractors/neoextractors/neuroscope.py b/src/spikeinterface/extractors/neoextractors/neuroscope.py index fbea66d62c..232b073332 100644 --- a/src/spikeinterface/extractors/neoextractors/neuroscope.py +++ b/src/spikeinterface/extractors/neoextractors/neuroscope.py @@ -51,7 +51,9 @@ def __init__(self, file_path, xml_file_path=None, stream_id=None, stream_name=No NeoBaseRecordingExtractor.__init__( self, stream_id=stream_id, stream_name=stream_name, all_annotations=all_annotations, **neo_kwargs ) - self._kwargs.update(dict(file_path=str(Path(file_path).resolve().absolute()), xml_file_path=xml_file_path)) + if xml_file_path is not None: + xml_file_path = str(Path(xml_file_path).absolute()) + self._kwargs.update(dict(file_path=str(Path(file_path).absolute()), xml_file_path=xml_file_path)) @classmethod def map_to_neo_kwargs(cls, file_path, xml_file_path=None): diff --git a/src/spikeinterface/extractors/neoextractors/nix.py b/src/spikeinterface/extractors/neoextractors/nix.py index d91084efe0..2cd336083c 100644 --- a/src/spikeinterface/extractors/neoextractors/nix.py +++ b/src/spikeinterface/extractors/neoextractors/nix.py @@ -39,12 +39,12 @@ def __init__(self, file_path, stream_id=None, stream_name=None, block_index=None all_annotations=all_annotations, **neo_kwargs, ) - self._kwargs.update(dict(file_path=str(Path(file_path).resolve().absolute()), stream_id=stream_id)) + self._kwargs.update(dict(file_path=str(Path(file_path).absolute()), stream_id=stream_id)) self.extra_requirements.append("neo[nixio]") @classmethod def map_to_neo_kwargs(cls, file_path): - neo_kwargs = {"filename": str(Path(file_path).resolve().absolute())} + neo_kwargs = {"filename": str(Path(file_path).absolute())} return neo_kwargs diff --git a/src/spikeinterface/extractors/neoextractors/openephys.py b/src/spikeinterface/extractors/neoextractors/openephys.py index ea0fbfad9b..9ec104cf02 100644 --- a/src/spikeinterface/extractors/neoextractors/openephys.py +++ b/src/spikeinterface/extractors/neoextractors/openephys.py @@ -61,11 +61,11 @@ def __init__(self, folder_path, stream_id=None, stream_name=None, block_index=No all_annotations=all_annotations, **neo_kwargs, ) - self._kwargs.update(dict(folder_path=str(Path(folder_path).resolve().absolute()))) + self._kwargs.update(dict(folder_path=str(Path(folder_path).absolute()))) @classmethod def map_to_neo_kwargs(cls, folder_path): - neo_kwargs = {"dirname": str(Path(folder_path).resolve().absolute())} + neo_kwargs = {"dirname": str(Path(folder_path).absolute())} return neo_kwargs @@ -204,7 +204,7 @@ def __init__( self._kwargs.update( dict( - 
folder_path=str(Path(folder_path).resolve().absolute()), + folder_path=str(Path(folder_path).absolute()), load_sync_channel=load_sync_channel, load_sync_timestamps=load_sync_timestamps, experiment_names=experiment_names, @@ -214,7 +214,7 @@ def __init__( @classmethod def map_to_neo_kwargs(cls, folder_path, load_sync_channel=False, experiment_names=None): neo_kwargs = { - "dirname": str(Path(folder_path).resolve().absolute()), + "dirname": str(Path(folder_path).absolute()), "load_sync_channel": load_sync_channel, "experiment_names": experiment_names, } @@ -248,7 +248,7 @@ def __init__(self, folder_path, block_index=None): @classmethod def map_to_neo_kwargs(cls, folder_path): - neo_kwargs = {"dirname": str(Path(folder_path).resolve().absolute())} + neo_kwargs = {"dirname": str(Path(folder_path).absolute())} return neo_kwargs diff --git a/src/spikeinterface/extractors/neoextractors/plexon.py b/src/spikeinterface/extractors/neoextractors/plexon.py index 9b8adf25f5..d07037aebf 100644 --- a/src/spikeinterface/extractors/neoextractors/plexon.py +++ b/src/spikeinterface/extractors/neoextractors/plexon.py @@ -32,11 +32,11 @@ def __init__(self, file_path, stream_id=None, stream_name=None, all_annotations= NeoBaseRecordingExtractor.__init__( self, stream_id=stream_id, stream_name=stream_name, all_annotations=all_annotations, **neo_kwargs ) - self._kwargs.update({"file_path": str(Path(file_path).resolve().absolute())}) + self._kwargs.update({"file_path": str(Path(file_path).absolute())}) @classmethod def map_to_neo_kwargs(cls, file_path): - neo_kwargs = {"filename": str(Path(file_path).resolve().absolute())} + neo_kwargs = {"filename": str(Path(file_path).absolute())} return neo_kwargs @@ -63,11 +63,11 @@ def __init__(self, file_path): self.neo_reader.parse_header() sampling_frequency = self.neo_reader._global_ssampling_rate NeoBaseSortingExtractor.__init__(self, sampling_frequency=sampling_frequency, **neo_kwargs) - self._kwargs = {"file_path": str(Path(file_path).resolve().absolute())} + self._kwargs = {"file_path": str(Path(file_path).absolute())} @classmethod def map_to_neo_kwargs(cls, file_path): - neo_kwargs = {"filename": str(Path(file_path).resolve().absolute())} + neo_kwargs = {"filename": str(Path(file_path).absolute())} return neo_kwargs diff --git a/src/spikeinterface/extractors/neoextractors/spike2.py b/src/spikeinterface/extractors/neoextractors/spike2.py index 0efd2e221c..5bc03aac0d 100644 --- a/src/spikeinterface/extractors/neoextractors/spike2.py +++ b/src/spikeinterface/extractors/neoextractors/spike2.py @@ -33,12 +33,12 @@ def __init__(self, file_path, stream_id=None, stream_name=None, all_annotations= NeoBaseRecordingExtractor.__init__( self, stream_id=stream_id, stream_name=stream_name, all_annotations=all_annotations, **neo_kwargs ) - self._kwargs.update({"file_path": str(Path(file_path).resolve().absolute())}) + self._kwargs.update({"file_path": str(Path(file_path).absolute())}) self.extra_requirements.append("sonpy") @classmethod def map_to_neo_kwargs(cls, file_path): - neo_kwargs = {"filename": str(Path(file_path).resolve().absolute())} + neo_kwargs = {"filename": str(Path(file_path).absolute())} return neo_kwargs diff --git a/src/spikeinterface/extractors/neoextractors/spikegadgets.py b/src/spikeinterface/extractors/neoextractors/spikegadgets.py index d6c85eb5de..b5b5ace95d 100644 --- a/src/spikeinterface/extractors/neoextractors/spikegadgets.py +++ b/src/spikeinterface/extractors/neoextractors/spikegadgets.py @@ -32,11 +32,11 @@ def __init__(self, file_path, stream_id=None, 
stream_name=None, block_index=None NeoBaseRecordingExtractor.__init__( self, stream_id=stream_id, stream_name=stream_name, all_annotations=all_annotations, **neo_kwargs ) - self._kwargs.update(dict(file_path=str(Path(file_path).resolve().absolute()), stream_id=stream_id)) + self._kwargs.update(dict(file_path=str(Path(file_path).absolute()), stream_id=stream_id)) @classmethod def map_to_neo_kwargs(cls, file_path): - neo_kwargs = {"filename": str(Path(file_path).resolve().absolute())} + neo_kwargs = {"filename": str(Path(file_path).absolute())} return neo_kwargs diff --git a/src/spikeinterface/extractors/neoextractors/spikeglx.py b/src/spikeinterface/extractors/neoextractors/spikeglx.py index 8566a8b0cd..74e7008d01 100644 --- a/src/spikeinterface/extractors/neoextractors/spikeglx.py +++ b/src/spikeinterface/extractors/neoextractors/spikeglx.py @@ -91,13 +91,11 @@ def __init__(self, folder_path, load_sync_channel=False, stream_id=None, stream_ self.set_property("inter_sample_shift", sample_shifts) - self._kwargs.update( - dict(folder_path=str(Path(folder_path).resolve().absolute()), load_sync_channel=load_sync_channel) - ) + self._kwargs.update(dict(folder_path=str(Path(folder_path).absolute()), load_sync_channel=load_sync_channel)) @classmethod def map_to_neo_kwargs(cls, folder_path, load_sync_channel=False): - neo_kwargs = {"dirname": str(Path(folder_path).resolve().absolute()), "load_sync_channel": load_sync_channel} + neo_kwargs = {"dirname": str(Path(folder_path).absolute()), "load_sync_channel": load_sync_channel} return neo_kwargs diff --git a/src/spikeinterface/extractors/neoextractors/tdt.py b/src/spikeinterface/extractors/neoextractors/tdt.py index 3a543a5131..6de46498b2 100644 --- a/src/spikeinterface/extractors/neoextractors/tdt.py +++ b/src/spikeinterface/extractors/neoextractors/tdt.py @@ -37,11 +37,11 @@ def __init__(self, folder_path, stream_id=None, stream_name=None, block_index=No all_annotations=all_annotations, **neo_kwargs, ) - self._kwargs.update(dict(folder_path=str(Path(folder_path).resolve().absolute()))) + self._kwargs.update(dict(folder_path=str(Path(folder_path).absolute()))) @classmethod def map_to_neo_kwargs(cls, folder_path): - neo_kwargs = {"dirname": str(Path(folder_path).resolve().absolute())} + neo_kwargs = {"dirname": str(Path(folder_path).absolute())} return neo_kwargs diff --git a/src/spikeinterface/extractors/nwbextractors.py b/src/spikeinterface/extractors/nwbextractors.py index d9bafdcb67..d0b56342dd 100644 --- a/src/spikeinterface/extractors/nwbextractors.py +++ b/src/spikeinterface/extractors/nwbextractors.py @@ -104,7 +104,7 @@ def read_nwbfile( -------- >>> nwbfile = read_nwbfile("data.nwb", stream_mode="ros3") """ - file_path = str(Path(file_path).resolve().absolute()) + file_path = str(Path(file_path).absolute()) from pynwb import NWBHDF5IO, NWBFile if stream_mode == "fsspec": @@ -359,7 +359,7 @@ def __init__( self.set_property(property_name, values) if stream_mode not in ["fsspec", "ros3"]: - file_path = str(Path(file_path).resolve().absolute()) + file_path = str(Path(file_path).absolute()) if stream_mode == "fsspec": # only add stream_cache_path to kwargs if it was passed as an argument if stream_cache_path is not None: @@ -477,16 +477,16 @@ def __init__( fs=fsspec.filesystem("http"), cache_storage=self.stream_cache_path, ) - self._file_path = self.cfs.open(str(Path(file_path).resolve().absolute()), "rb") + self._file_path = self.cfs.open(str(Path(file_path).absolute()), "rb") file = h5py.File(self._file_path) self.io = NWBHDF5IO(file=file, 
mode="r", load_namespaces=True) elif stream_mode == "ros3": - self._file_path = str(Path(file_path).resolve().absolute()) + self._file_path = str(Path(file_path).absolute()) self.io = NWBHDF5IO(self._file_path, mode="r", load_namespaces=True, driver="ros3") else: - self._file_path = str(Path(file_path).resolve().absolute()) + self._file_path = str(Path(file_path).absolute()) self.io = NWBHDF5IO(self._file_path, mode="r", load_namespaces=True) self._nwbfile = self.io.read() @@ -537,7 +537,7 @@ def __init__( self.set_property(prop_name, np.array(values)) if stream_mode not in ["fsspec", "ros3"]: - file_path = str(Path(file_path).resolve().absolute()) + file_path = str(Path(file_path).absolute()) if stream_mode == "fsspec": stream_cache_path = str(Path(self.stream_cache_path).absolute()) self._kwargs = { diff --git a/src/spikeinterface/extractors/shybridextractors.py b/src/spikeinterface/extractors/shybridextractors.py index 10176a7502..cefa738ae1 100644 --- a/src/spikeinterface/extractors/shybridextractors.py +++ b/src/spikeinterface/extractors/shybridextractors.py @@ -169,7 +169,7 @@ def __init__(self, file_path, sampling_frequency, delimiter=","): if Path(file_path).is_file(): spike_clusters = sbio.SpikeClusters() - spike_clusters.fromCSV(str(Path(file_path).resolve().absolute()), None, delimiter=delimiter) + spike_clusters.fromCSV(str(Path(file_path).absolute()), None, delimiter=delimiter) else: raise FileNotFoundError(f"The ground truth file {file_path} could not be found") @@ -179,7 +179,7 @@ def __init__(self, file_path, sampling_frequency, delimiter=","): self.add_sorting_segment(sorting_segment) self._kwargs = { - "file_path": str(Path(file_path).resolve().absolute()), + "file_path": str(Path(file_path).absolute()), "sampling_frequency": sampling_frequency, "delimiter": delimiter, } diff --git a/src/spikeinterface/extractors/spykingcircusextractors.py b/src/spikeinterface/extractors/spykingcircusextractors.py index 5fab383759..11bf91b93b 100644 --- a/src/spikeinterface/extractors/spykingcircusextractors.py +++ b/src/spikeinterface/extractors/spykingcircusextractors.py @@ -85,7 +85,7 @@ def __init__(self, folder_path): BaseSorting.__init__(self, sample_rate, unit_ids) self.add_sorting_segment(SpykingcircustSortingSegment(unit_ids, spiketrains)) - self._kwargs = {"folder_path": str(Path(folder_path).resolve().absolute())} + self._kwargs = {"folder_path": str(Path(folder_path).absolute())} self.extra_requirements.append("h5py") diff --git a/src/spikeinterface/extractors/tests/common_tests.py b/src/spikeinterface/extractors/tests/common_tests.py index c1a98698b0..858c86d92a 100644 --- a/src/spikeinterface/extractors/tests/common_tests.py +++ b/src/spikeinterface/extractors/tests/common_tests.py @@ -38,7 +38,6 @@ def test_open(self): # test streams and blocks retrieval full_path = self.get_full_path(path) - rec = self.ExtractorClass(full_path, **kwargs) assert hasattr(rec, "extra_requirements") diff --git a/src/spikeinterface/extractors/yassextractors.py b/src/spikeinterface/extractors/yassextractors.py index 2fbaa72226..bb04c21533 100644 --- a/src/spikeinterface/extractors/yassextractors.py +++ b/src/spikeinterface/extractors/yassextractors.py @@ -56,7 +56,7 @@ def __init__(self, folder_path): BaseSorting.__init__(self, sampling_frequency, unit_ids) self.add_sorting_segment(YassSortingSegment(spiketrains)) - self._kwargs = {"folder_path": str(Path(folder_path).resolve().absolute())} + self._kwargs = {"folder_path": str(Path(folder_path).absolute())} 
self.extra_requirements.append("pyyaml") diff --git a/src/spikeinterface/sorters/basesorter.py b/src/spikeinterface/sorters/basesorter.py index 138cceaeb6..7ea2fe5a23 100644 --- a/src/spikeinterface/sorters/basesorter.py +++ b/src/spikeinterface/sorters/basesorter.py @@ -121,7 +121,7 @@ def initialize_folder(cls, recording, output_folder, verbose, remove_existing_fo output_folder = cls.sorter_name + "_output" # Resolve path - output_folder = Path(output_folder).resolve().absolute() + output_folder = Path(output_folder).absolute() sorter_output_folder = output_folder / "sorter_output" if output_folder.is_dir(): From c36a40ba087ee28da9adffae98f827b9cb9a33d2 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 17 Jul 2023 19:41:21 +0200 Subject: [PATCH 069/166] Pin numpy<1.25 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 1b6d116e4b..a456c23755 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,7 @@ classifiers = [ ] dependencies = [ - "numpy", + "numpy<1.25", "neo>=0.11.1", "joblib", "threadpoolctl", From 0b91327330068bd059ccc9e3171f6f42619c6b4a Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 17 Jul 2023 19:45:01 +0200 Subject: [PATCH 070/166] oups --- src/spikeinterface/widgets/matplotlib/motion.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/widgets/matplotlib/motion.py b/src/spikeinterface/widgets/matplotlib/motion.py index c4f32e4e75..8a89351c8a 100644 --- a/src/spikeinterface/widgets/matplotlib/motion.py +++ b/src/spikeinterface/widgets/matplotlib/motion.py @@ -42,14 +42,13 @@ def do_plot(self, data_plot, **backend_kwargs): x = dp.peaks["sample_index"] / dp.sampling_frequency else: # use real times and adjust temporal bins with t_start - temporal_bins_plot = dp.temporal_bins + times[0] - x = times[dp.peaks["sample_index"]] + temporal_bins_plot = dp.temporal_bins + dp.times[0] + x = dp.times[dp.peaks["sample_index"]] corrected_location = correct_motion_on_peaks( dp.peaks, dp.peak_locations, - dp, - sampling_frequency, + dp.sampling_frequency, dp.motion, dp.temporal_bins, dp.spatial_bins, From e0a327be3530547d20b46fc2ea43782365395685 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 17 Jul 2023 19:48:58 +0200 Subject: [PATCH 071/166] Simplify skip if pytests --- .../comparison/tests/test_groundtruthstudy.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/comparison/tests/test_groundtruthstudy.py b/src/spikeinterface/comparison/tests/test_groundtruthstudy.py index 9d495e64c5..70f8a63c8c 100644 --- a/src/spikeinterface/comparison/tests/test_groundtruthstudy.py +++ b/src/spikeinterface/comparison/tests/test_groundtruthstudy.py @@ -7,6 +7,13 @@ from spikeinterface.sorters import installed_sorters from spikeinterface.comparison import GroundTruthStudy +try: + import tridesclous + + HAVE_TDC = True +except ImportError: + HAVE_TDC = False + if hasattr(pytest, "global_test_folder"): cache_folder = pytest.global_test_folder / "comparison" @@ -34,9 +41,7 @@ def _setup_comparison_study(): study = GroundTruthStudy.create(study_folder, gt_dict) -@pytest.mark.skipif( - importlib.util.find_spec("tridesclous") is None, reason="Test requires Python package 'tridesclous'" -) +@pytest.mark.skipif(not HAVE_TDC, reason="Test requires Python package 'tridesclous'") def test_run_study_sorters(): study = GroundTruthStudy(study_folder) sorter_list = [ @@ -49,9 +54,7 @@ def test_run_study_sorters(): 
study.run_sorters(sorter_list) -@pytest.mark.skipif( - importlib.util.find_spec("tridesclous") is None, reason="Test requires Python package 'tridesclous'" -) +@pytest.mark.skipif(not HAVE_TDC, reason="Test requires Python package 'tridesclous'") def test_extract_sortings(): study = GroundTruthStudy(study_folder) From 4ed94bb4e942cd7c60dd013265b9dc1bdaa82e83 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 17 Jul 2023 19:51:32 +0200 Subject: [PATCH 072/166] Install hdbscan with conda --- docs_rtd.yml | 1 + pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/docs_rtd.yml b/docs_rtd.yml index c4e1fb378c..975aafb46b 100644 --- a/docs_rtd.yml +++ b/docs_rtd.yml @@ -5,5 +5,6 @@ dependencies: - python=3.10 - pip - datalad + - hdbscan - pip: - -e .[docs] diff --git a/pyproject.toml b/pyproject.toml index a456c23755..1b6d116e4b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,7 @@ classifiers = [ ] dependencies = [ - "numpy<1.25", + "numpy", "neo>=0.11.1", "joblib", "threadpoolctl", From 22708ae18739c9fa1703a2649a95fc613c012732 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 18 Jul 2023 07:01:03 +0200 Subject: [PATCH 073/166] [pre-commit.ci] pre-commit autoupdate (#1839) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/psf/black: 23.3.0 → 23.7.0](https://github.com/psf/black/compare/23.3.0...23.7.0) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 816e4e24d6..ced1ee6a2f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,7 +6,7 @@ repos: - id: end-of-file-fixer - id: trailing-whitespace - repo: https://github.com/psf/black - rev: 23.3.0 + rev: 23.7.0 hooks: - id: black files: ^src/ From edc1b3b4ef756a8ef93e0ae1f88e17c5a320559d Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 18 Jul 2023 08:28:49 +0200 Subject: [PATCH 074/166] howto plot_motion --- doc/how_to/handle_drift.rst | 2 +- examples/how_to/handle_drift.py | 19 ++++++++++++++++++- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/doc/how_to/handle_drift.rst b/doc/how_to/handle_drift.rst index 53b68e8c17..c0a27ff0a3 100644 --- a/doc/how_to/handle_drift.rst +++ b/doc/how_to/handle_drift.rst @@ -204,7 +204,7 @@ A few comments on the figures: # and plot fig = plt.figure(figsize=(14, 8)) - si.plot_motion(motion_info, recording=rec, figure=fig, depth_lim=(400, 600), + si.plot_motion(motion_info, figure=fig, depth_lim=(400, 600), color_amplitude=True, amplitude_cmap='inferno', scatter_decimate=10) fig.suptitle(f"{preset=}") diff --git a/examples/how_to/handle_drift.py b/examples/how_to/handle_drift.py index 26841f49dd..a1671a7424 100644 --- a/examples/how_to/handle_drift.py +++ b/examples/how_to/handle_drift.py @@ -1,3 +1,19 @@ +# --- +# jupyter: +# jupytext: +# cell_metadata_filter: -all +# formats: py,ipynb +# text_representation: +# extension: .py +# format_name: light +# format_version: '1.5' +# jupytext_version: 1.14.6 +# kernelspec: +# display_name: Python 3 (ipykernel) +# language: python +# name: python3 +# --- + # %matplotlib inline # %load_ext autoreload # %autoreload 2 @@ -119,8 +135,9 @@ def preprocess_chain(rec): # and plot fig = plt.figure(figsize=(14, 8)) - si.plot_motion(motion_info, recording=rec, figure=fig, 
depth_lim=(400, 600), + si.plot_motion(motion_info, figure=fig, depth_lim=(400, 600), color_amplitude=True, amplitude_cmap='inferno', scatter_decimate=10) + fig.suptitle(f"{preset=}") # ### Plot peak localization From 46337713c13cefa59b9686d2c6648d9ed87b7d54 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 18 Jul 2023 09:55:19 +0200 Subject: [PATCH 075/166] Remove Path().absolute from map_neo_kwargs and fixes --- src/spikeinterface/core/base.py | 4 ++-- src/spikeinterface/extractors/cbin_ibl.py | 2 +- src/spikeinterface/extractors/neoextractors/alphaomega.py | 4 ++-- src/spikeinterface/extractors/neoextractors/axona.py | 4 ++-- src/spikeinterface/extractors/neoextractors/biocam.py | 2 +- src/spikeinterface/extractors/neoextractors/blackrock.py | 6 +++--- src/spikeinterface/extractors/neoextractors/ced.py | 2 +- src/spikeinterface/extractors/neoextractors/edf.py | 4 ++-- src/spikeinterface/extractors/neoextractors/intan.py | 2 +- src/spikeinterface/extractors/neoextractors/maxwell.py | 2 +- src/spikeinterface/extractors/neoextractors/mcsraw.py | 2 +- src/spikeinterface/extractors/neoextractors/mearec.py | 4 ++-- src/spikeinterface/extractors/neoextractors/neuralynx.py | 6 +++--- src/spikeinterface/extractors/neoextractors/nix.py | 2 +- src/spikeinterface/extractors/neoextractors/openephys.py | 6 +++--- src/spikeinterface/extractors/neoextractors/plexon.py | 4 ++-- src/spikeinterface/extractors/neoextractors/spike2.py | 2 +- src/spikeinterface/extractors/neoextractors/spikegadgets.py | 2 +- src/spikeinterface/extractors/neoextractors/spikeglx.py | 2 +- src/spikeinterface/extractors/neoextractors/tdt.py | 2 +- 20 files changed, 32 insertions(+), 32 deletions(-) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 829c8b2637..61ba7b535c 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -646,10 +646,10 @@ def load(file_path: Union[str, Path], base_folder: Optional[Union[Path, str, boo if file_path.is_file(): # standard case based on a file (json or pickle) if str(file_path).endswith(".json"): - with open(str(Path(file_path)), "r") as f: + with open(file_path, "r") as f: d = json.load(f) elif str(file_path).endswith(".pkl") or str(file_path).endswith(".pickle"): - with open(str(Path(file_path)), "rb") as f: + with open(file_path, "rb") as f: d = pickle.load(f) else: raise ValueError(f"Impossible to load {file_path}") diff --git a/src/spikeinterface/extractors/cbin_ibl.py b/src/spikeinterface/extractors/cbin_ibl.py index 926009cb1c..1fac418e85 100644 --- a/src/spikeinterface/extractors/cbin_ibl.py +++ b/src/spikeinterface/extractors/cbin_ibl.py @@ -107,7 +107,7 @@ def __init__(self, folder_path, load_sync_channel=False): self.set_property("inter_sample_shift", sample_shifts) self._kwargs = { - "folder_path": str(Path(folder_path).resolve.absolute()), + "folder_path": str(Path(folder_path).absolute()), "load_sync_channel": load_sync_channel, } diff --git a/src/spikeinterface/extractors/neoextractors/alphaomega.py b/src/spikeinterface/extractors/neoextractors/alphaomega.py index d806a07fe9..11a1869e77 100644 --- a/src/spikeinterface/extractors/neoextractors/alphaomega.py +++ b/src/spikeinterface/extractors/neoextractors/alphaomega.py @@ -39,7 +39,7 @@ def __init__(self, folder_path, lsx_files=None, stream_id="RAW", stream_name=Non @classmethod def map_to_neo_kwargs(cls, folder_path, lsx_files=None): neo_kwargs = { - "dirname": str(Path(folder_path).absolute()), + "dirname": folder_path, "lsx_files": lsx_files, } return 
neo_kwargs @@ -60,7 +60,7 @@ def __init__(self, folder_path): @classmethod def map_to_neo_kwargs(cls, folder_path): - neo_kwargs = {"dirname": str(Path(folder_path).absolute())} + neo_kwargs = {"dirname": folder_path} return neo_kwargs diff --git a/src/spikeinterface/extractors/neoextractors/axona.py b/src/spikeinterface/extractors/neoextractors/axona.py index 0772ea0a26..6f91e0062f 100644 --- a/src/spikeinterface/extractors/neoextractors/axona.py +++ b/src/spikeinterface/extractors/neoextractors/axona.py @@ -26,11 +26,11 @@ class AxonaRecordingExtractor(NeoBaseRecordingExtractor): def __init__(self, file_path, all_annotations=False): neo_kwargs = self.map_to_neo_kwargs(file_path) NeoBaseRecordingExtractor.__init__(self, all_annotations=all_annotations, **neo_kwargs) - self._kwargs.update({"file_path": file_path}) + self._kwargs.update({"file_path": str(Path(file_path).absolute())}) @classmethod def map_to_neo_kwargs(cls, file_path): - neo_kwargs = {"filename": str(Path(file_path).absolute())} + neo_kwargs = {"filename": file_path} return neo_kwargs diff --git a/src/spikeinterface/extractors/neoextractors/biocam.py b/src/spikeinterface/extractors/neoextractors/biocam.py index f3483b1cd5..d5b4de0454 100644 --- a/src/spikeinterface/extractors/neoextractors/biocam.py +++ b/src/spikeinterface/extractors/neoextractors/biocam.py @@ -69,7 +69,7 @@ def __init__( @classmethod def map_to_neo_kwargs(cls, file_path): - neo_kwargs = {"filename": str(Path(file_path).absolute())} + neo_kwargs = {"filename": file_path} return neo_kwargs diff --git a/src/spikeinterface/extractors/neoextractors/blackrock.py b/src/spikeinterface/extractors/neoextractors/blackrock.py index f36995e37e..a9ac2bddbe 100644 --- a/src/spikeinterface/extractors/neoextractors/blackrock.py +++ b/src/spikeinterface/extractors/neoextractors/blackrock.py @@ -60,7 +60,7 @@ def __init__( @classmethod def map_to_neo_kwargs(cls, file_path): - neo_kwargs = {"filename": str(Path(file_path).absolute())} + neo_kwargs = {"filename": file_path} return neo_kwargs @@ -107,7 +107,7 @@ def __init__( ) self._kwargs = { - "file_path": file_path, + "file_path": str(Path(file_path).absolute()), "sampling_frequency": sampling_frequency, "stream_id": stream_id, "stream_name": stream_name, @@ -115,7 +115,7 @@ def __init__( @classmethod def map_to_neo_kwargs(cls, file_path): - neo_kwargs = {"filename": str(Path(file_path).absolute())} + neo_kwargs = {"filename": file_path} return neo_kwargs diff --git a/src/spikeinterface/extractors/neoextractors/ced.py b/src/spikeinterface/extractors/neoextractors/ced.py index 49939c4897..d4f96d7fe3 100644 --- a/src/spikeinterface/extractors/neoextractors/ced.py +++ b/src/spikeinterface/extractors/neoextractors/ced.py @@ -41,7 +41,7 @@ def __init__(self, file_path, stream_id=None, stream_name=None, all_annotations= @classmethod def map_to_neo_kwargs(cls, file_path): - neo_kwargs = {"filename": str(Path(file_path).absolute())} + neo_kwargs = {"filename": file_path} return neo_kwargs diff --git a/src/spikeinterface/extractors/neoextractors/edf.py b/src/spikeinterface/extractors/neoextractors/edf.py index 309a686b69..7cb43aa70b 100644 --- a/src/spikeinterface/extractors/neoextractors/edf.py +++ b/src/spikeinterface/extractors/neoextractors/edf.py @@ -29,7 +29,7 @@ class EDFRecordingExtractor(NeoBaseRecordingExtractor): name = "edf" def __init__(self, file_path, stream_id=None, stream_name=None, all_annotations=False): - neo_kwargs = {"filename": str(Path(file_path).absolute())} + neo_kwargs = {"filename": file_path} 
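# At this point in the series the map_to_neo_kwargs-style dicts start
# forwarding the raw file_path; a follow-up commit below ("Always pass
# strings to NEO kwargs") wraps each value in str(), since callers may hand
# in either a str or a pathlib.Path while neo readers are safest with plain
# strings. Hedged sketch of the eventual pattern (illustrative only,
# mirroring the final hunks further down):
def map_to_neo_kwargs(file_path):
    # str() is a no-op for strings and converts Path objects
    return {"filename": str(file_path)}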
NeoBaseRecordingExtractor.__init__( self, stream_id=stream_id, stream_name=stream_name, all_annotations=all_annotations, **neo_kwargs ) @@ -38,7 +38,7 @@ def __init__(self, file_path, stream_id=None, stream_name=None, all_annotations= @classmethod def map_to_neo_kwargs(cls, file_path): - neo_kwargs = {"filename": str(Path(file_path).absolute())} + neo_kwargs = {"filename": file_path} return neo_kwargs diff --git a/src/spikeinterface/extractors/neoextractors/intan.py b/src/spikeinterface/extractors/neoextractors/intan.py index 330a1e4682..e9765784ab 100644 --- a/src/spikeinterface/extractors/neoextractors/intan.py +++ b/src/spikeinterface/extractors/neoextractors/intan.py @@ -42,7 +42,7 @@ def __init__(self, file_path, stream_id=None, stream_name=None, all_annotations= @classmethod def map_to_neo_kwargs(cls, file_path): - neo_kwargs = {"filename": str(Path(file_path).absolute())} + neo_kwargs = {"filename": file_path} return neo_kwargs diff --git a/src/spikeinterface/extractors/neoextractors/maxwell.py b/src/spikeinterface/extractors/neoextractors/maxwell.py index 602a054ed8..304c986e30 100644 --- a/src/spikeinterface/extractors/neoextractors/maxwell.py +++ b/src/spikeinterface/extractors/neoextractors/maxwell.py @@ -75,7 +75,7 @@ def __init__( @classmethod def map_to_neo_kwargs(cls, file_path, rec_name=None): - neo_kwargs = {"filename": str(Path(file_path).absolute()), "rec_name": rec_name} + neo_kwargs = {"filename": file_path, "rec_name": rec_name} return neo_kwargs def install_maxwell_plugin(self, force_download=False): diff --git a/src/spikeinterface/extractors/neoextractors/mcsraw.py b/src/spikeinterface/extractors/neoextractors/mcsraw.py index 84d0ff1b7d..d3d75f365a 100644 --- a/src/spikeinterface/extractors/neoextractors/mcsraw.py +++ b/src/spikeinterface/extractors/neoextractors/mcsraw.py @@ -46,7 +46,7 @@ def __init__(self, file_path, stream_id=None, stream_name=None, block_index=None @classmethod def map_to_neo_kwargs(cls, file_path): - neo_kwargs = {"filename": str(Path(file_path).absolute())} + neo_kwargs = {"filename": file_path} return neo_kwargs diff --git a/src/spikeinterface/extractors/neoextractors/mearec.py b/src/spikeinterface/extractors/neoextractors/mearec.py index 4fcdfef006..f6ac3e2393 100644 --- a/src/spikeinterface/extractors/neoextractors/mearec.py +++ b/src/spikeinterface/extractors/neoextractors/mearec.py @@ -64,7 +64,7 @@ def map_to_neo_kwargs( file_path, ): neo_kwargs = { - "filename": str(Path(file_path).absolute()), + "filename": file_path, "load_spiketrains": False, "load_analogsignal": True, } @@ -90,7 +90,7 @@ def __init__(self, file_path: Union[str, Path]): @classmethod def map_to_neo_kwargs(cls, file_path): neo_kwargs = { - "filename": str(Path(file_path).absolute()), + "filename": file_path, "load_spiketrains": True, "load_analogsignal": False, } diff --git a/src/spikeinterface/extractors/neoextractors/neuralynx.py b/src/spikeinterface/extractors/neoextractors/neuralynx.py index 351fbd6e44..ef2a39cbba 100644 --- a/src/spikeinterface/extractors/neoextractors/neuralynx.py +++ b/src/spikeinterface/extractors/neoextractors/neuralynx.py @@ -37,7 +37,7 @@ def __init__(self, folder_path, stream_id=None, stream_name=None, all_annotation @classmethod def map_to_neo_kwargs(cls, folder_path): - neo_kwargs = {"dirname": str(Path(folder_path).absolute())} + neo_kwargs = {"dirname": folder_path} return neo_kwargs @@ -83,7 +83,7 @@ def __init__( ) self._kwargs = { - "folder_path": folder_path, + "folder_path": str(Path(folder_path).absolute()), "sampling_frequency": 
sampling_frequency, "stream_id": stream_id, "stream_name": stream_name, @@ -91,7 +91,7 @@ def __init__( @classmethod def map_to_neo_kwargs(cls, folder_path): - neo_kwargs = {"dirname": str(Path(folder_path).absolute())} + neo_kwargs = {"dirname": folder_path} return neo_kwargs diff --git a/src/spikeinterface/extractors/neoextractors/nix.py b/src/spikeinterface/extractors/neoextractors/nix.py index 2cd336083c..43c82cf427 100644 --- a/src/spikeinterface/extractors/neoextractors/nix.py +++ b/src/spikeinterface/extractors/neoextractors/nix.py @@ -44,7 +44,7 @@ def __init__(self, file_path, stream_id=None, stream_name=None, block_index=None @classmethod def map_to_neo_kwargs(cls, file_path): - neo_kwargs = {"filename": str(Path(file_path).absolute())} + neo_kwargs = {"filename": file_path} return neo_kwargs diff --git a/src/spikeinterface/extractors/neoextractors/openephys.py b/src/spikeinterface/extractors/neoextractors/openephys.py index 26841057c5..c91e3ad684 100644 --- a/src/spikeinterface/extractors/neoextractors/openephys.py +++ b/src/spikeinterface/extractors/neoextractors/openephys.py @@ -65,7 +65,7 @@ def __init__(self, folder_path, stream_id=None, stream_name=None, block_index=No @classmethod def map_to_neo_kwargs(cls, folder_path): - neo_kwargs = {"dirname": str(Path(folder_path).absolute())} + neo_kwargs = {"dirname": folder_path} return neo_kwargs @@ -214,7 +214,7 @@ def __init__( @classmethod def map_to_neo_kwargs(cls, folder_path, load_sync_channel=False, experiment_names=None): neo_kwargs = { - "dirname": str(Path(folder_path).absolute()), + "dirname": folder_path, "load_sync_channel": load_sync_channel, "experiment_names": experiment_names, } @@ -248,7 +248,7 @@ def __init__(self, folder_path, block_index=None): @classmethod def map_to_neo_kwargs(cls, folder_path): - neo_kwargs = {"dirname": str(Path(folder_path).absolute())} + neo_kwargs = {"dirname": folder_path} return neo_kwargs diff --git a/src/spikeinterface/extractors/neoextractors/plexon.py b/src/spikeinterface/extractors/neoextractors/plexon.py index d07037aebf..2439854165 100644 --- a/src/spikeinterface/extractors/neoextractors/plexon.py +++ b/src/spikeinterface/extractors/neoextractors/plexon.py @@ -36,7 +36,7 @@ def __init__(self, file_path, stream_id=None, stream_name=None, all_annotations= @classmethod def map_to_neo_kwargs(cls, file_path): - neo_kwargs = {"filename": str(Path(file_path).absolute())} + neo_kwargs = {"filename": file_path} return neo_kwargs @@ -67,7 +67,7 @@ def __init__(self, file_path): @classmethod def map_to_neo_kwargs(cls, file_path): - neo_kwargs = {"filename": str(Path(file_path).absolute())} + neo_kwargs = {"filename": file_path} return neo_kwargs diff --git a/src/spikeinterface/extractors/neoextractors/spike2.py b/src/spikeinterface/extractors/neoextractors/spike2.py index 5bc03aac0d..5a1b1a8bec 100644 --- a/src/spikeinterface/extractors/neoextractors/spike2.py +++ b/src/spikeinterface/extractors/neoextractors/spike2.py @@ -38,7 +38,7 @@ def __init__(self, file_path, stream_id=None, stream_name=None, all_annotations= @classmethod def map_to_neo_kwargs(cls, file_path): - neo_kwargs = {"filename": str(Path(file_path).absolute())} + neo_kwargs = {"filename": file_path} return neo_kwargs diff --git a/src/spikeinterface/extractors/neoextractors/spikegadgets.py b/src/spikeinterface/extractors/neoextractors/spikegadgets.py index b5b5ace95d..e014df60be 100644 --- a/src/spikeinterface/extractors/neoextractors/spikegadgets.py +++ b/src/spikeinterface/extractors/neoextractors/spikegadgets.py @@ -36,7 
+36,7 @@ def __init__(self, file_path, stream_id=None, stream_name=None, block_index=None @classmethod def map_to_neo_kwargs(cls, file_path): - neo_kwargs = {"filename": str(Path(file_path).absolute())} + neo_kwargs = {"filename": file_path} return neo_kwargs diff --git a/src/spikeinterface/extractors/neoextractors/spikeglx.py b/src/spikeinterface/extractors/neoextractors/spikeglx.py index 74e7008d01..dd15260ba1 100644 --- a/src/spikeinterface/extractors/neoextractors/spikeglx.py +++ b/src/spikeinterface/extractors/neoextractors/spikeglx.py @@ -95,7 +95,7 @@ def __init__(self, folder_path, load_sync_channel=False, stream_id=None, stream_ @classmethod def map_to_neo_kwargs(cls, folder_path, load_sync_channel=False): - neo_kwargs = {"dirname": str(Path(folder_path).absolute()), "load_sync_channel": load_sync_channel} + neo_kwargs = {"dirname": folder_path, "load_sync_channel": load_sync_channel} return neo_kwargs diff --git a/src/spikeinterface/extractors/neoextractors/tdt.py b/src/spikeinterface/extractors/neoextractors/tdt.py index 6de46498b2..007b22109f 100644 --- a/src/spikeinterface/extractors/neoextractors/tdt.py +++ b/src/spikeinterface/extractors/neoextractors/tdt.py @@ -41,7 +41,7 @@ def __init__(self, folder_path, stream_id=None, stream_name=None, block_index=No @classmethod def map_to_neo_kwargs(cls, folder_path): - neo_kwargs = {"dirname": str(Path(folder_path).absolute())} + neo_kwargs = {"dirname": folder_path} return neo_kwargs From fe6418baf48fbfddd4d5bcb44accf10b95d54338 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 18 Jul 2023 09:58:02 +0200 Subject: [PATCH 076/166] one more --- src/spikeinterface/extractors/shybridextractors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/extractors/shybridextractors.py b/src/spikeinterface/extractors/shybridextractors.py index cefa738ae1..130c0ce47e 100644 --- a/src/spikeinterface/extractors/shybridextractors.py +++ b/src/spikeinterface/extractors/shybridextractors.py @@ -169,7 +169,7 @@ def __init__(self, file_path, sampling_frequency, delimiter=","): if Path(file_path).is_file(): spike_clusters = sbio.SpikeClusters() - spike_clusters.fromCSV(str(Path(file_path).absolute()), None, delimiter=delimiter) + spike_clusters.fromCSV(str(file_path), None, delimiter=delimiter) else: raise FileNotFoundError(f"The ground truth file {file_path} could not be found") From 7632443914f37283b696cdbdba50e326a98bcf0a Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 18 Jul 2023 10:00:18 +0200 Subject: [PATCH 077/166] Prepare release 0.98.1 --- doc/releases/0.98.1.rst | 25 +++++++++++++++++++++++++ doc/whatisnew.rst | 7 +++++++ pyproject.toml | 6 +++--- 3 files changed, 35 insertions(+), 3 deletions(-) create mode 100644 doc/releases/0.98.1.rst diff --git a/doc/releases/0.98.1.rst b/doc/releases/0.98.1.rst new file mode 100644 index 0000000000..4f3dbe2e23 --- /dev/null +++ b/doc/releases/0.98.1.rst @@ -0,0 +1,25 @@ +.. _release0.98.1: + +SpikeInterface 0.98.1 release notes +----------------------------------- + +18th July 2023 + +Minor release with some bug fixes. 
+ +* Make all paths resolved and absolute (#1834) +* Improve Documentation (#1809) +* Fix hdbascan installation in read the docs (#1838) +* Fixed numba.jit and binary num_chan warnings (#1836) +* Fix neo release bug in Mearec (#1835) +* Do not load NP probe in OE if load_sync_channel=True (#1832) +* Cleanup dumping/to_dict (#1831) +* Expose AUCpslit param in KS2+ (#1829) +* Add option relative_to=True (#1820) +* plot_motion: make recording optional, add amplitude_clim and alpha (#1818) +* Fix typo in class attribute for NeuralynxSortingExtractor (#1814) +* Make to_phy write templates.npy with datatype np.float64 as required by phy (#1810) +* Add docs requirements and build read-the-docs documentation faster (#1807) +* Fix has_channel_locations function (#1806) +* Add depth_order kwargs (#1803) + diff --git a/doc/whatisnew.rst b/doc/whatisnew.rst index 1a61946aa7..21ad89af62 100644 --- a/doc/whatisnew.rst +++ b/doc/whatisnew.rst @@ -8,6 +8,7 @@ Release notes .. toctree:: :maxdepth: 1 + releases/0.98.1.rst releases/0.98.0.rst releases/0.97.1.rst releases/0.97.0.rst @@ -29,6 +30,12 @@ Release notes releases/0.9.1.rst +Version 0.98.1 +============== + +* Minor release with some bug fixes + + Version 0.98.0 ============== diff --git a/pyproject.toml b/pyproject.toml index 1b6d116e4b..191470fd2a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "spikeinterface" -version = "0.99.0.dev0" +version = "0.98.1" authors = [ { name="Alessio Buccino", email="alessiop.buccino@gmail.com" }, { name="Samuel Garcia", email="sam.garcia.die@gmail.com" }, @@ -156,8 +156,8 @@ docs = [ "hdbscan", # For sorters, probably spikingcircus "numba", # For sorters, probably spikingcircus # for release we need pypi, so this needs to be commented - "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", # We always build from the latest version - "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", # We always build from the latest version + # "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", # We always build from the latest version + # "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", # We always build from the latest version ] From 5b1baab40a17ca00de2d48fad0b65aae83f468ea Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 18 Jul 2023 08:01:49 +0000 Subject: [PATCH 078/166] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- doc/releases/0.98.1.rst | 1 - 1 file changed, 1 deletion(-) diff --git a/doc/releases/0.98.1.rst b/doc/releases/0.98.1.rst index 4f3dbe2e23..b713e2fbd2 100644 --- a/doc/releases/0.98.1.rst +++ b/doc/releases/0.98.1.rst @@ -22,4 +22,3 @@ Minor release with some bug fixes. 
* Add docs requirements and build read-the-docs documentation faster (#1807) * Fix has_channel_locations function (#1806) * Add depth_order kwargs (#1803) - From 78f85727ccfd28d2115c452cd27f35ff9414237e Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 18 Jul 2023 10:15:07 +0200 Subject: [PATCH 079/166] Always pass strings to NEO kwargs --- src/spikeinterface/extractors/neoextractors/alphaomega.py | 4 ++-- src/spikeinterface/extractors/neoextractors/axona.py | 2 +- src/spikeinterface/extractors/neoextractors/biocam.py | 2 +- src/spikeinterface/extractors/neoextractors/blackrock.py | 4 ++-- src/spikeinterface/extractors/neoextractors/ced.py | 2 +- src/spikeinterface/extractors/neoextractors/edf.py | 4 ++-- src/spikeinterface/extractors/neoextractors/intan.py | 2 +- src/spikeinterface/extractors/neoextractors/maxwell.py | 2 +- src/spikeinterface/extractors/neoextractors/mcsraw.py | 2 +- src/spikeinterface/extractors/neoextractors/mearec.py | 4 ++-- src/spikeinterface/extractors/neoextractors/neuralynx.py | 4 ++-- src/spikeinterface/extractors/neoextractors/neuroscope.py | 4 ++-- src/spikeinterface/extractors/neoextractors/nix.py | 2 +- src/spikeinterface/extractors/neoextractors/openephys.py | 6 +++--- src/spikeinterface/extractors/neoextractors/plexon.py | 4 ++-- src/spikeinterface/extractors/neoextractors/spike2.py | 2 +- src/spikeinterface/extractors/neoextractors/spikegadgets.py | 2 +- src/spikeinterface/extractors/neoextractors/spikeglx.py | 2 +- src/spikeinterface/extractors/neoextractors/tdt.py | 2 +- 19 files changed, 28 insertions(+), 28 deletions(-) diff --git a/src/spikeinterface/extractors/neoextractors/alphaomega.py b/src/spikeinterface/extractors/neoextractors/alphaomega.py index 11a1869e77..a58b5ab5ec 100644 --- a/src/spikeinterface/extractors/neoextractors/alphaomega.py +++ b/src/spikeinterface/extractors/neoextractors/alphaomega.py @@ -39,7 +39,7 @@ def __init__(self, folder_path, lsx_files=None, stream_id="RAW", stream_name=Non @classmethod def map_to_neo_kwargs(cls, folder_path, lsx_files=None): neo_kwargs = { - "dirname": folder_path, + "dirname": str(folder_path), "lsx_files": lsx_files, } return neo_kwargs @@ -60,7 +60,7 @@ def __init__(self, folder_path): @classmethod def map_to_neo_kwargs(cls, folder_path): - neo_kwargs = {"dirname": folder_path} + neo_kwargs = {"dirname": str(folder_path)} return neo_kwargs diff --git a/src/spikeinterface/extractors/neoextractors/axona.py b/src/spikeinterface/extractors/neoextractors/axona.py index 6f91e0062f..6b1d47e4fa 100644 --- a/src/spikeinterface/extractors/neoextractors/axona.py +++ b/src/spikeinterface/extractors/neoextractors/axona.py @@ -30,7 +30,7 @@ def __init__(self, file_path, all_annotations=False): @classmethod def map_to_neo_kwargs(cls, file_path): - neo_kwargs = {"filename": file_path} + neo_kwargs = {"filename": str(file_path)} return neo_kwargs diff --git a/src/spikeinterface/extractors/neoextractors/biocam.py b/src/spikeinterface/extractors/neoextractors/biocam.py index d5b4de0454..3e30cf77ae 100644 --- a/src/spikeinterface/extractors/neoextractors/biocam.py +++ b/src/spikeinterface/extractors/neoextractors/biocam.py @@ -69,7 +69,7 @@ def __init__( @classmethod def map_to_neo_kwargs(cls, file_path): - neo_kwargs = {"filename": file_path} + neo_kwargs = {"filename": str(file_path)} return neo_kwargs diff --git a/src/spikeinterface/extractors/neoextractors/blackrock.py b/src/spikeinterface/extractors/neoextractors/blackrock.py index a9ac2bddbe..8300e6bc5e 100644 --- 
a/src/spikeinterface/extractors/neoextractors/blackrock.py +++ b/src/spikeinterface/extractors/neoextractors/blackrock.py @@ -60,7 +60,7 @@ def __init__( @classmethod def map_to_neo_kwargs(cls, file_path): - neo_kwargs = {"filename": file_path} + neo_kwargs = {"filename": str(file_path)} return neo_kwargs @@ -115,7 +115,7 @@ def __init__( @classmethod def map_to_neo_kwargs(cls, file_path): - neo_kwargs = {"filename": file_path} + neo_kwargs = {"filename": str(file_path)} return neo_kwargs diff --git a/src/spikeinterface/extractors/neoextractors/ced.py b/src/spikeinterface/extractors/neoextractors/ced.py index d4f96d7fe3..2451ca8fe1 100644 --- a/src/spikeinterface/extractors/neoextractors/ced.py +++ b/src/spikeinterface/extractors/neoextractors/ced.py @@ -41,7 +41,7 @@ def __init__(self, file_path, stream_id=None, stream_name=None, all_annotations= @classmethod def map_to_neo_kwargs(cls, file_path): - neo_kwargs = {"filename": file_path} + neo_kwargs = {"filename": str(file_path)} return neo_kwargs diff --git a/src/spikeinterface/extractors/neoextractors/edf.py b/src/spikeinterface/extractors/neoextractors/edf.py index 7cb43aa70b..5d8c56ee87 100644 --- a/src/spikeinterface/extractors/neoextractors/edf.py +++ b/src/spikeinterface/extractors/neoextractors/edf.py @@ -29,7 +29,7 @@ class EDFRecordingExtractor(NeoBaseRecordingExtractor): name = "edf" def __init__(self, file_path, stream_id=None, stream_name=None, all_annotations=False): - neo_kwargs = {"filename": file_path} + neo_kwargs = {"filename": str(file_path)} NeoBaseRecordingExtractor.__init__( self, stream_id=stream_id, stream_name=stream_name, all_annotations=all_annotations, **neo_kwargs ) @@ -38,7 +38,7 @@ def __init__(self, file_path, stream_id=None, stream_name=None, all_annotations= @classmethod def map_to_neo_kwargs(cls, file_path): - neo_kwargs = {"filename": file_path} + neo_kwargs = {"filename": str(file_path)} return neo_kwargs diff --git a/src/spikeinterface/extractors/neoextractors/intan.py b/src/spikeinterface/extractors/neoextractors/intan.py index e9765784ab..2a61e7385f 100644 --- a/src/spikeinterface/extractors/neoextractors/intan.py +++ b/src/spikeinterface/extractors/neoextractors/intan.py @@ -42,7 +42,7 @@ def __init__(self, file_path, stream_id=None, stream_name=None, all_annotations= @classmethod def map_to_neo_kwargs(cls, file_path): - neo_kwargs = {"filename": file_path} + neo_kwargs = {"filename": str(file_path)} return neo_kwargs diff --git a/src/spikeinterface/extractors/neoextractors/maxwell.py b/src/spikeinterface/extractors/neoextractors/maxwell.py index 304c986e30..ac85dbdf30 100644 --- a/src/spikeinterface/extractors/neoextractors/maxwell.py +++ b/src/spikeinterface/extractors/neoextractors/maxwell.py @@ -75,7 +75,7 @@ def __init__( @classmethod def map_to_neo_kwargs(cls, file_path, rec_name=None): - neo_kwargs = {"filename": file_path, "rec_name": rec_name} + neo_kwargs = {"filename": str(file_path), "rec_name": rec_name} return neo_kwargs def install_maxwell_plugin(self, force_download=False): diff --git a/src/spikeinterface/extractors/neoextractors/mcsraw.py b/src/spikeinterface/extractors/neoextractors/mcsraw.py index d3d75f365a..4b6af54bcd 100644 --- a/src/spikeinterface/extractors/neoextractors/mcsraw.py +++ b/src/spikeinterface/extractors/neoextractors/mcsraw.py @@ -46,7 +46,7 @@ def __init__(self, file_path, stream_id=None, stream_name=None, block_index=None @classmethod def map_to_neo_kwargs(cls, file_path): - neo_kwargs = {"filename": file_path} + neo_kwargs = {"filename": str(file_path)} 
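# Why the stored _kwargs are cast with str(Path(...).absolute()) throughout
# this series: extractor kwargs are serialized (e.g. to JSON) for provenance,
# and pathlib.Path objects are not JSON serializable. Hedged sketch of the
# assumed failure mode (not code from this patch):
import json
from pathlib import Path

json.dumps({"file_path": str(Path("rec.raw").absolute())})  # fine, plain string
# json.dumps({"file_path": Path("rec.raw")}) would raise
# TypeError: Object of type PosixPath is not JSON serializable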
return neo_kwargs diff --git a/src/spikeinterface/extractors/neoextractors/mearec.py b/src/spikeinterface/extractors/neoextractors/mearec.py index f6ac3e2393..0ec7326f3f 100644 --- a/src/spikeinterface/extractors/neoextractors/mearec.py +++ b/src/spikeinterface/extractors/neoextractors/mearec.py @@ -64,7 +64,7 @@ def map_to_neo_kwargs( file_path, ): neo_kwargs = { - "filename": file_path, + "filename": str(file_path), "load_spiketrains": False, "load_analogsignal": True, } @@ -90,7 +90,7 @@ def __init__(self, file_path: Union[str, Path]): @classmethod def map_to_neo_kwargs(cls, file_path): neo_kwargs = { - "filename": file_path, + "filename": str(file_path), "load_spiketrains": True, "load_analogsignal": False, } diff --git a/src/spikeinterface/extractors/neoextractors/neuralynx.py b/src/spikeinterface/extractors/neoextractors/neuralynx.py index ef2a39cbba..672602b66c 100644 --- a/src/spikeinterface/extractors/neoextractors/neuralynx.py +++ b/src/spikeinterface/extractors/neoextractors/neuralynx.py @@ -37,7 +37,7 @@ def __init__(self, folder_path, stream_id=None, stream_name=None, all_annotation @classmethod def map_to_neo_kwargs(cls, folder_path): - neo_kwargs = {"dirname": folder_path} + neo_kwargs = {"dirname": str(folder_path)} return neo_kwargs @@ -91,7 +91,7 @@ def __init__( @classmethod def map_to_neo_kwargs(cls, folder_path): - neo_kwargs = {"dirname": folder_path} + neo_kwargs = {"dirname": str(folder_path)} return neo_kwargs diff --git a/src/spikeinterface/extractors/neoextractors/neuroscope.py b/src/spikeinterface/extractors/neoextractors/neuroscope.py index 232b073332..a41441b8b7 100644 --- a/src/spikeinterface/extractors/neoextractors/neuroscope.py +++ b/src/spikeinterface/extractors/neoextractors/neuroscope.py @@ -62,9 +62,9 @@ def map_to_neo_kwargs(cls, file_path, xml_file_path=None): # binary_file is the binary file in .dat, .lfp, .eeg if xml_file_path is not None: - neo_kwargs = {"binary_file": file_path, "filename": xml_file_path} + neo_kwargs = {"binary_file": str(file_path), "filename": str(xml_file_path)} else: - neo_kwargs = {"filename": file_path} + neo_kwargs = {"filename": str(file_path)} return neo_kwargs diff --git a/src/spikeinterface/extractors/neoextractors/nix.py b/src/spikeinterface/extractors/neoextractors/nix.py index 43c82cf427..2762e5645b 100644 --- a/src/spikeinterface/extractors/neoextractors/nix.py +++ b/src/spikeinterface/extractors/neoextractors/nix.py @@ -44,7 +44,7 @@ def __init__(self, file_path, stream_id=None, stream_name=None, block_index=None @classmethod def map_to_neo_kwargs(cls, file_path): - neo_kwargs = {"filename": file_path} + neo_kwargs = {"filename": str(file_path)} return neo_kwargs diff --git a/src/spikeinterface/extractors/neoextractors/openephys.py b/src/spikeinterface/extractors/neoextractors/openephys.py index c91e3ad684..a771dc47b1 100644 --- a/src/spikeinterface/extractors/neoextractors/openephys.py +++ b/src/spikeinterface/extractors/neoextractors/openephys.py @@ -65,7 +65,7 @@ def __init__(self, folder_path, stream_id=None, stream_name=None, block_index=No @classmethod def map_to_neo_kwargs(cls, folder_path): - neo_kwargs = {"dirname": folder_path} + neo_kwargs = {"dirname": str(folder_path)} return neo_kwargs @@ -214,7 +214,7 @@ def __init__( @classmethod def map_to_neo_kwargs(cls, folder_path, load_sync_channel=False, experiment_names=None): neo_kwargs = { - "dirname": folder_path, + "dirname": str(folder_path), "load_sync_channel": load_sync_channel, "experiment_names": experiment_names, } @@ -248,7 +248,7 @@ def 
__init__(self, folder_path, block_index=None): @classmethod def map_to_neo_kwargs(cls, folder_path): - neo_kwargs = {"dirname": folder_path} + neo_kwargs = {"dirname": str(folder_path)} return neo_kwargs diff --git a/src/spikeinterface/extractors/neoextractors/plexon.py b/src/spikeinterface/extractors/neoextractors/plexon.py index 2439854165..c3ff59fe82 100644 --- a/src/spikeinterface/extractors/neoextractors/plexon.py +++ b/src/spikeinterface/extractors/neoextractors/plexon.py @@ -36,7 +36,7 @@ def __init__(self, file_path, stream_id=None, stream_name=None, all_annotations= @classmethod def map_to_neo_kwargs(cls, file_path): - neo_kwargs = {"filename": file_path} + neo_kwargs = {"filename": str(file_path)} return neo_kwargs @@ -67,7 +67,7 @@ def __init__(self, file_path): @classmethod def map_to_neo_kwargs(cls, file_path): - neo_kwargs = {"filename": file_path} + neo_kwargs = {"filename": str(file_path)} return neo_kwargs diff --git a/src/spikeinterface/extractors/neoextractors/spike2.py b/src/spikeinterface/extractors/neoextractors/spike2.py index 5a1b1a8bec..af172855ed 100644 --- a/src/spikeinterface/extractors/neoextractors/spike2.py +++ b/src/spikeinterface/extractors/neoextractors/spike2.py @@ -38,7 +38,7 @@ def __init__(self, file_path, stream_id=None, stream_name=None, all_annotations= @classmethod def map_to_neo_kwargs(cls, file_path): - neo_kwargs = {"filename": file_path} + neo_kwargs = {"filename": str(file_path)} return neo_kwargs diff --git a/src/spikeinterface/extractors/neoextractors/spikegadgets.py b/src/spikeinterface/extractors/neoextractors/spikegadgets.py index e014df60be..49d55ca3eb 100644 --- a/src/spikeinterface/extractors/neoextractors/spikegadgets.py +++ b/src/spikeinterface/extractors/neoextractors/spikegadgets.py @@ -36,7 +36,7 @@ def __init__(self, file_path, stream_id=None, stream_name=None, block_index=None @classmethod def map_to_neo_kwargs(cls, file_path): - neo_kwargs = {"filename": file_path} + neo_kwargs = {"filename": str(file_path)} return neo_kwargs diff --git a/src/spikeinterface/extractors/neoextractors/spikeglx.py b/src/spikeinterface/extractors/neoextractors/spikeglx.py index dd15260ba1..8c3b33505d 100644 --- a/src/spikeinterface/extractors/neoextractors/spikeglx.py +++ b/src/spikeinterface/extractors/neoextractors/spikeglx.py @@ -95,7 +95,7 @@ def __init__(self, folder_path, load_sync_channel=False, stream_id=None, stream_ @classmethod def map_to_neo_kwargs(cls, folder_path, load_sync_channel=False): - neo_kwargs = {"dirname": folder_path, "load_sync_channel": load_sync_channel} + neo_kwargs = {"dirname": str(folder_path), "load_sync_channel": load_sync_channel} return neo_kwargs diff --git a/src/spikeinterface/extractors/neoextractors/tdt.py b/src/spikeinterface/extractors/neoextractors/tdt.py index 007b22109f..60cd39c010 100644 --- a/src/spikeinterface/extractors/neoextractors/tdt.py +++ b/src/spikeinterface/extractors/neoextractors/tdt.py @@ -41,7 +41,7 @@ def __init__(self, folder_path, stream_id=None, stream_name=None, block_index=No @classmethod def map_to_neo_kwargs(cls, folder_path): - neo_kwargs = {"dirname": folder_path} + neo_kwargs = {"dirname": str(folder_path)} return neo_kwargs From fe609bcc11e79dc5085c2765ccc6865d0db01c48 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 18 Jul 2023 10:19:13 +0200 Subject: [PATCH 080/166] release 0.98.1 --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 191470fd2a..3b0b4e0f2d 100644 --- a/pyproject.toml +++ 
b/pyproject.toml @@ -139,8 +139,8 @@ test = [ # for github test : probeinterface and neo from master # for release we need pypi, so this need to be commented - "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", - "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", + # "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", + # "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", ] docs = [ From 833d9f5e19e4ffc913d3baaad28cf6986a0eab56 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 18 Jul 2023 11:02:36 +0200 Subject: [PATCH 081/166] neo and probeinterface version constrain --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 3b0b4e0f2d..8e3abaf5cd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,11 +20,11 @@ classifiers = [ dependencies = [ "numpy", - "neo>=0.11.1", + "neo>=0.12.0", "joblib", "threadpoolctl", "tqdm", - "probeinterface>=0.2.16", + "probeinterface>=0.2.17", ] [build-system] From 6b09bf1c7322810efeaa373be8e8598fe95aef6e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20WYNGAARD?= Date: Tue, 18 Jul 2023 11:25:05 +0200 Subject: [PATCH 082/166] Remove warning BinaryRecordingExtractor with `num_chan` --- .../sortingcomponents/clustering/clustering_tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 53833b01a2..6edf5af16b 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -579,7 +579,7 @@ def remove_duplicates_via_matching( f.write(blanck) f.close() - recording = BinaryRecordingExtractor(tmp_filename, num_chan=num_chans, sampling_frequency=fs, dtype="float32") + recording = BinaryRecordingExtractor(tmp_filename, num_channels=num_chans, sampling_frequency=fs, dtype="float32") recording.annotate(is_filtered=True) margin = 2 * max(waveform_extractor.nbefore, waveform_extractor.nafter) From bccc462a0c89b588df7c48a8f84b18bd00f24dfc Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 18 Jul 2023 11:38:21 +0200 Subject: [PATCH 083/166] refactor wip --- src/spikeinterface/widgets/base.py | 22 ++++-- src/spikeinterface/widgets/unit_locations.py | 52 +++---------- src/spikeinterface/widgets/widget_list.py | 81 ++++++++++---------- 3 files changed, 66 insertions(+), 89 deletions(-) diff --git a/src/spikeinterface/widgets/base.py b/src/spikeinterface/widgets/base.py index 7b62dc3507..3b708c57d7 100644 --- a/src/spikeinterface/widgets/base.py +++ b/src/spikeinterface/widgets/base.py @@ -55,7 +55,7 @@ class BaseWidget: # this need to be reset in the subclass possible_backends = None - def __init__(self, data_plot=None, backend=None, **backend_kwargs): + def __init__(self, data_plot=None, backend=None, **backend_kwargs, do_plot=True): # every widgets must prepare a dict "plot_data" in the init self.data_plot = data_plot self.backend = backend @@ -72,9 +72,11 @@ def __init__(self, data_plot=None, backend=None, **backend_kwargs): self.backend_kwargs = backend_kwargs_ + if do_plot: + self.do_plot() + + - func = getattr(self, f'plot_{backend}') - func(self) def check_backend(self, backend): @@ -96,11 +98,15 @@ def check_backend(self, backend): # ) def do_plot(self, backend, **backend_kwargs): - backend = self.check_backend(backend) - plotter = 
self.possible_backends[backend]() - self.check_backend_kwargs(plotter, backend, **backend_kwargs) - plotter.do_plot(self.plot_data, **backend_kwargs) - self.plotter = plotter + # backend = self.check_backend(backend) + # plotter = self.possible_backends[backend]() + # self.check_backend_kwargs(plotter, backend, **backend_kwargs) + # plotter.do_plot(self.plot_data, **backend_kwargs) + # self.plotter = plotter + + func = getattr(self, f'plot_{backend}') + func(self, self.data_plot, self.backend_kwargs) + @classmethod def register_backend(cls, backend_plotter): diff --git a/src/spikeinterface/widgets/unit_locations.py b/src/spikeinterface/widgets/unit_locations.py index be9d9cacc8..4ea306bad6 100644 --- a/src/spikeinterface/widgets/unit_locations.py +++ b/src/spikeinterface/widgets/unit_locations.py @@ -82,7 +82,7 @@ def __init__( BaseWidget.__init__(self, data_plot, backend=backend, **backend_kwargs) - def plot_matplotlib(self, **backend_kwargs): + def plot_matplotlib(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt from .matplotlib_utils import make_mpl_figure from probeinterface.plotting import plot_probe @@ -93,12 +93,12 @@ def plot_matplotlib(self, **backend_kwargs): - dp = to_attr(self.data_plot) + dp = to_attr(data_plot) # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) # self.make_mpl_figure(**backend_kwargs) - self.figure, self.axes, self.ax = make_mpl_figure(self.backend_kwargs) + self.figure, self.axes, self.ax = make_mpl_figure(backend_kwargs) unit_locations = dp.unit_locations @@ -171,13 +171,12 @@ def plot_matplotlib(self, **backend_kwargs): if dp.hide_axis: self.ax.axis("off") - def plot_sortingview(self): + def plot_sortingview(self, data_plot, **backend_kwargs): import sortingview.views as vv from .sortingview_utils import generate_unit_table_view, make_serializable, handle_display_and_url # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - backend_kwargs = self.backend_kwargs - dp = to_attr(self.data_plot) + dp = to_attr(data_plot) # ensure serializable for sortingview unit_ids, channel_ids = make_serializable(dp.unit_ids, dp.channel_ids) @@ -215,6 +214,8 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): from .ipywidgets_utils import check_ipywidget_backend, make_unit_controller + self.next_data_plot = data_plot.copy() + cm = 1 / 2.54 # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) @@ -228,7 +229,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): with plt.ioff(): output = widgets.Output() with output: - fig, ax = plt.subplots(figsize=((ratios[1] * width_cm) * cm, height_cm * cm)) + fig, self.ax = plt.subplots(figsize=((ratios[1] * width_cm) * cm, height_cm * cm)) plt.show() data_plot["unit_ids"] = data_plot["unit_ids"][:1] @@ -265,40 +266,6 @@ def update_widget(self, change): unit_ids = self.controller["unit_ids"].value - # matplotlib next_data_plot dict update at each call - # data_plot = self.next_data_plot - self.data_plot["unit_ids"] = unit_ids - self.data_plot["plot_all_units"] = True - self.data_plot["plot_legend"] = True - self.data_plot["hide_axis"] = True - - backend_kwargs = {} - backend_kwargs["ax"] = self.ax - - # self.mpl_plotter.do_plot(data_plot, **backend_kwargs) - self.plot_matplotlib(data_plot, **backend_kwargs) - fig = self.ax.get_figure() - fig.canvas.draw() - fig.canvas.flush_events() - - - - - -class PlotUpdater: - def __init__(self, data_plot, mpl_plotter, ax, controller): - self.data_plot = data_plot - self.mpl_plotter = mpl_plotter - self.ax = ax - self.controller = 
controller - - self.next_data_plot = data_plot.copy() - - def __call__(self, change): - self.ax.clear() - - unit_ids = self.controller["unit_ids"].value - # matplotlib next_data_plot dict update at each call data_plot = self.next_data_plot data_plot["unit_ids"] = unit_ids @@ -310,9 +277,10 @@ def __call__(self, change): backend_kwargs["ax"] = self.ax # self.mpl_plotter.do_plot(data_plot, **backend_kwargs) - UnitLocationsWidget.plot_matplotlib(data_plot, **backend_kwargs) + self.plot_matplotlib(data_plot, **backend_kwargs) fig = self.ax.get_figure() fig.canvas.draw() fig.canvas.flush_events() + diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index a6e0896e99..4dbd4b3c68 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -56,25 +56,25 @@ widget_list = [ - AmplitudesWidget, - AllAmplitudesDistributionsWidget, - AutoCorrelogramsWidget, - CrossCorrelogramsWidget, - QualityMetricsWidget, - SpikeLocationsWidget, - SpikesOnTracesWidget, - TemplateMetricsWidget, - MotionWidget, - TemplateSimilarityWidget, - TimeseriesWidget, + # AmplitudesWidget, + # AllAmplitudesDistributionsWidget, + # AutoCorrelogramsWidget, + # CrossCorrelogramsWidget, + # QualityMetricsWidget, + # SpikeLocationsWidget, + # SpikesOnTracesWidget, + # TemplateMetricsWidget, + # MotionWidget, + # TemplateSimilarityWidget, + # TimeseriesWidget, UnitLocationsWidget, - UnitTemplatesWidget, - UnitWaveformsWidget, - UnitWaveformDensityMapWidget, - UnitDepthsWidget, + # UnitTemplatesWidget, + # UnitWaveformsWidget, + # UnitWaveformDensityMapWidget, + # UnitDepthsWidget, # summary - UnitSummaryWidget, - SortingSummaryWidget, + # UnitSummaryWidget, + # SortingSummaryWidget, ] @@ -101,25 +101,28 @@ # make function for all widgets -plot_amplitudes = define_widget_function_from_class(AmplitudesWidget, "plot_amplitudes") -plot_all_amplitudes_distributions = define_widget_function_from_class( - AllAmplitudesDistributionsWidget, "plot_all_amplitudes_distributions" -) -plot_autocorrelograms = define_widget_function_from_class(AutoCorrelogramsWidget, "plot_autocorrelograms") -plot_crosscorrelograms = define_widget_function_from_class(CrossCorrelogramsWidget, "plot_crosscorrelograms") -plot_quality_metrics = define_widget_function_from_class(QualityMetricsWidget, "plot_quality_metrics") -plot_spike_locations = define_widget_function_from_class(SpikeLocationsWidget, "plot_spike_locations") -plot_spikes_on_traces = define_widget_function_from_class(SpikesOnTracesWidget, "plot_spikes_on_traces") -plot_template_metrics = define_widget_function_from_class(TemplateMetricsWidget, "plot_template_metrics") -plot_motion = define_widget_function_from_class(MotionWidget, "plot_motion") -plot_template_similarity = define_widget_function_from_class(TemplateSimilarityWidget, "plot_template_similarity") -plot_timeseries = define_widget_function_from_class(TimeseriesWidget, "plot_timeseries") -plot_unit_locations = define_widget_function_from_class(UnitLocationsWidget, "plot_unit_locations") -plot_unit_templates = define_widget_function_from_class(UnitTemplatesWidget, "plot_unit_templates") -plot_unit_waveforms = define_widget_function_from_class(UnitWaveformsWidget, "plot_unit_waveforms") -plot_unit_waveforms_density_map = define_widget_function_from_class( - UnitWaveformDensityMapWidget, "plot_unit_waveforms_density_map" -) -plot_unit_depths = define_widget_function_from_class(UnitDepthsWidget, "plot_unit_depths") -plot_unit_summary = 
define_widget_function_from_class(UnitSummaryWidget, "plot_unit_summary")
-plot_sorting_summary = define_widget_function_from_class(SortingSummaryWidget, "plot_sorting_summary")
+# plot_amplitudes = define_widget_function_from_class(AmplitudesWidget, "plot_amplitudes")
+# plot_all_amplitudes_distributions = define_widget_function_from_class(
+#     AllAmplitudesDistributionsWidget, "plot_all_amplitudes_distributions"
+# )
+# plot_autocorrelograms = define_widget_function_from_class(AutoCorrelogramsWidget, "plot_autocorrelograms")
+# plot_crosscorrelograms = define_widget_function_from_class(CrossCorrelogramsWidget, "plot_crosscorrelograms")
+# plot_quality_metrics = define_widget_function_from_class(QualityMetricsWidget, "plot_quality_metrics")
+# plot_spike_locations = define_widget_function_from_class(SpikeLocationsWidget, "plot_spike_locations")
+# plot_spikes_on_traces = define_widget_function_from_class(SpikesOnTracesWidget, "plot_spikes_on_traces")
+# plot_template_metrics = define_widget_function_from_class(TemplateMetricsWidget, "plot_template_metrics")
+# plot_motion = define_widget_function_from_class(MotionWidget, "plot_motion")
+# plot_template_similarity = define_widget_function_from_class(TemplateSimilarityWidget, "plot_template_similarity")
+# plot_timeseries = define_widget_function_from_class(TimeseriesWidget, "plot_timeseries")
+# plot_unit_locations = define_widget_function_from_class(UnitLocationsWidget, "plot_unit_locations")
+# plot_unit_templates = define_widget_function_from_class(UnitTemplatesWidget, "plot_unit_templates")
+# plot_unit_waveforms = define_widget_function_from_class(UnitWaveformsWidget, "plot_unit_waveforms")
+# plot_unit_waveforms_density_map = define_widget_function_from_class(
+#     UnitWaveformDensityMapWidget, "plot_unit_waveforms_density_map"
+# )
+# plot_unit_depths = define_widget_function_from_class(UnitDepthsWidget, "plot_unit_depths")
+# plot_unit_summary = define_widget_function_from_class(UnitSummaryWidget, "plot_unit_summary")
+# plot_sorting_summary = define_widget_function_from_class(SortingSummaryWidget, "plot_sorting_summary")
+
+
+plot_unit_locations = UnitLocationsWidget

From ab7c8af5be84d3af492411019e6c451e436043bf Mon Sep 17 00:00:00 2001
From: Samuel Garcia <sam.garcia.die@gmail.com>
Date: Tue, 18 Jul 2023 11:39:21 +0200
Subject: [PATCH 084/166] after release

---
 pyproject.toml | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 8e3abaf5cd..e767904fef 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "spikeinterface"
-version = "0.98.1"
+version = "0.99.0.dev0"
 authors = [
   { name="Alessio Buccino", email="alessiop.buccino@gmail.com" },
   { name="Samuel Garcia", email="sam.garcia.die@gmail.com" },
@@ -139,8 +139,8 @@ test = [
 
     # for github test : probeinterface and neo from master
     # for release we need pypi, so this need to be commented
-    # "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git",
-    # "neo @ 
git+https://github.com/NeuralEnsemble/python-neo.git", # We always build from the latest version + "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", # We always build from the latest version + "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", # We always build from the latest version ] From 479a456804edae9cba0ac538929cc88a67e8b9bb Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 18 Jul 2023 13:45:58 +0200 Subject: [PATCH 085/166] Pin hdbscan in tests --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index e767904fef..09228c0242 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -120,7 +120,7 @@ test = [ # tridesclous "numpy<1.24", "numba", - "hdbscan", + "hdbscan<=0.8.30", # for sortingview backend "sortingview", From 743c3c2f508ce7fe7f4b8e2b32ac7b3eefab8554 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 18 Jul 2023 13:55:08 +0200 Subject: [PATCH 086/166] Pin onse more hdbscan in tests --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 09228c0242..a640cb42fc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -120,7 +120,7 @@ test = [ # tridesclous "numpy<1.24", "numba", - "hdbscan<=0.8.30", + "hdbscan<=0.8.29", # for sortingview backend "sortingview", From d865760fb845d4b206e867ef4906a65b10295989 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 18 Jul 2023 14:02:49 +0200 Subject: [PATCH 087/166] Downgrade Cython --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index a640cb42fc..f3a4a23e77 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -120,7 +120,8 @@ test = [ # tridesclous "numpy<1.24", "numba", - "hdbscan<=0.8.29", + "Cython<3.0.0", + "hdbscan", # for sortingview backend "sortingview", From 6259db230055b0425703d041344fcd49dfc5c7f8 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 18 Jul 2023 14:13:14 +0200 Subject: [PATCH 088/166] Downgrade hdbscan as well --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index f3a4a23e77..d5cb7dacf6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -121,7 +121,7 @@ test = [ "numpy<1.24", "numba", "Cython<3.0.0", - "hdbscan", + "hdbscan<=0.8.29", # for sortingview backend "sortingview", From 750ad2495c7d049d7da1c4e065743c43a467ddc8 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 18 Jul 2023 15:02:45 +0200 Subject: [PATCH 089/166] widget wip --- src/spikeinterface/widgets/__init__.py | 44 ++++++------- src/spikeinterface/widgets/base.py | 61 +++++++++---------- .../widgets/tests/test_widgets.py | 49 ++++++++------- src/spikeinterface/widgets/unit_locations.py | 3 +- src/spikeinterface/widgets/widget_list.py | 50 ++++++++------- 5 files changed, 106 insertions(+), 101 deletions(-) diff --git a/src/spikeinterface/widgets/__init__.py b/src/spikeinterface/widgets/__init__.py index 83f4b85fee..bb779ff7fb 100644 --- a/src/spikeinterface/widgets/__init__.py +++ b/src/spikeinterface/widgets/__init__.py @@ -1,35 +1,35 @@ # check if backend are available -try: - import matplotlib +# try: +# import matplotlib - HAVE_MPL = True -except: - HAVE_MPL = False +# HAVE_MPL = True +# except: +# HAVE_MPL = False -try: - import sortingview +# try: +# import sortingview - HAVE_SV = True -except: - HAVE_SV = False +# HAVE_SV = True +# except: +# HAVE_SV = False -try: - import ipywidgets +# try: +# import ipywidgets 
- HAVE_IPYW = True -except: - HAVE_IPYW = False +# HAVE_IPYW = True +# except: +# HAVE_IPYW = False -# theses import make the Widget.resgister() at import time -if HAVE_MPL: - import spikeinterface.widgets.matplotlib +# # theses import make the Widget.resgister() at import time +# if HAVE_MPL: +# import spikeinterface.widgets.matplotlib -if HAVE_SV: - import spikeinterface.widgets.sortingview +# if HAVE_SV: +# import spikeinterface.widgets.sortingview -if HAVE_IPYW: - import spikeinterface.widgets.ipywidgets +# if HAVE_IPYW: +# import spikeinterface.widgets.ipywidgets # when importing widget list backend are already registered from .widget_list import * diff --git a/src/spikeinterface/widgets/base.py b/src/spikeinterface/widgets/base.py index 3b708c57d7..17903b495b 100644 --- a/src/spikeinterface/widgets/base.py +++ b/src/spikeinterface/widgets/base.py @@ -55,12 +55,12 @@ class BaseWidget: # this need to be reset in the subclass possible_backends = None - def __init__(self, data_plot=None, backend=None, **backend_kwargs, do_plot=True): + def __init__(self, data_plot=None, backend=None, immediate_plot=True, **backend_kwargs, ): # every widgets must prepare a dict "plot_data" in the init self.data_plot = data_plot - self.backend = backend - + self.backend = self.check_backend(backend) + # check backend kwargs for k in backend_kwargs: if k not in default_backend_kwargs[backend]: raise Exception( @@ -72,18 +72,18 @@ def __init__(self, data_plot=None, backend=None, **backend_kwargs, do_plot=True) self.backend_kwargs = backend_kwargs_ - if do_plot: - self.do_plot() - + if immediate_plot: + self.do_plot(self.backend, **self.backend_kwargs) - - + @classmethod + def get_possible_backends(cls): + return [ k for k in default_backend_kwargs if hasattr(cls, f"plot_{k}") ] def check_backend(self, backend): if backend is None: backend = get_default_plotter_backend() - assert backend in self.possible_backends, ( - f"{backend} backend not available! Available backends are: " f"{list(self.possible_backends.keys())}" + assert backend in self.get_possible_backends(), ( + f"{backend} backend not available! 
Available backends are: " f"{self.get_possible_backends()}" ) return backend @@ -99,18 +99,13 @@ def check_backend(self, backend): def do_plot(self, backend, **backend_kwargs): # backend = self.check_backend(backend) - # plotter = self.possible_backends[backend]() - # self.check_backend_kwargs(plotter, backend, **backend_kwargs) - # plotter.do_plot(self.plot_data, **backend_kwargs) - # self.plotter = plotter func = getattr(self, f'plot_{backend}') - func(self, self.data_plot, self.backend_kwargs) - + func(data_plot=self.data_plot, **self.backend_kwargs) - @classmethod - def register_backend(cls, backend_plotter): - cls.possible_backends[backend_plotter.backend] = backend_plotter + # @classmethod + # def register_backend(cls, backend_plotter): + # cls.possible_backends[backend_plotter.backend] = backend_plotter @staticmethod def check_extensions(waveform_extractor, extensions): @@ -142,12 +137,12 @@ def check_extensions(waveform_extractor, extensions): # return backend_kwargs_ -def copy_signature(source_fct): - def copy(target_fct): - target_fct.__signature__ = inspect.signature(source_fct) - return target_fct +# def copy_signature(source_fct): +# def copy(target_fct): +# target_fct.__signature__ = inspect.signature(source_fct) +# return target_fct - return copy +# return copy class to_attr(object): @@ -168,14 +163,14 @@ def __getattribute__(self, k): return d[k] -def define_widget_function_from_class(widget_class, name): - @copy_signature(widget_class) - def widget_func(*args, **kwargs): - W = widget_class(*args, **kwargs) - W.do_plot(W.backend, **W.backend_kwargs) - return W.plotter +# def define_widget_function_from_class(widget_class, name): +# @copy_signature(widget_class) +# def widget_func(*args, **kwargs): +# W = widget_class(*args, **kwargs) +# W.do_plot(W.backend, **W.backend_kwargs) +# return W.plotter - widget_func.__doc__ = widget_class.__doc__ - widget_func.__name__ = name +# widget_func.__doc__ = widget_class.__doc__ +# widget_func.__name__ = name - return widget_func +# return widget_func diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index 3a60a9d2c7..cb4341f044 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -13,7 +13,8 @@ from spikeinterface import extract_waveforms, load_waveforms, download_dataset, compute_sparsity -from spikeinterface.widgets import HAVE_MPL, HAVE_SV +# from spikeinterface.widgets import HAVE_MPL, HAVE_SV + import spikeinterface.extractors as se import spikeinterface.widgets as sw @@ -68,7 +69,10 @@ def setUpClass(cls): # make sparse waveforms cls.sparsity_radius = compute_sparsity(cls.we, method="radius", radius_um=50) cls.sparsity_best = compute_sparsity(cls.we, method="best_channels", num_channels=5) - cls.we_sparse = cls.we.save(folder=cache_folder / "mearec_test_sparse", sparsity=cls.sparsity_radius) + if (cache_folder / "mearec_test_sparse").is_dir(): + cls.we_sparse = load_waveforms(cache_folder / "mearec_test_sparse") + else: + cls.we_sparse = cls.we.save(folder=cache_folder / "mearec_test_sparse", sparsity=cls.sparsity_radius) cls.skip_backends = ["ipywidgets"] @@ -82,7 +86,7 @@ def setUpClass(cls): cls.gt_comp = sc.compare_sorter_to_ground_truth(cls.sorting, cls.sorting) def test_plot_timeseries(self): - possible_backends = list(sw.TimeseriesWidget.possible_backends.keys()) + possible_backends = list(sw.TimeseriesWidget.get_possible_backends()) for backend in possible_backends: if ON_GITHUB and backend == 
"sortingview": continue @@ -119,7 +123,7 @@ def test_plot_timeseries(self): ) def test_plot_unit_waveforms(self): - possible_backends = list(sw.UnitWaveformsWidget.possible_backends.keys()) + possible_backends = list(sw.UnitWaveformsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: sw.plot_unit_waveforms(self.we, backend=backend, **self.backend_kwargs[backend]) @@ -143,7 +147,7 @@ def test_plot_unit_waveforms(self): ) def test_plot_unit_templates(self): - possible_backends = list(sw.UnitWaveformsWidget.possible_backends.keys()) + possible_backends = list(sw.UnitWaveformsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: sw.plot_unit_templates(self.we, backend=backend, **self.backend_kwargs[backend]) @@ -164,7 +168,7 @@ def test_plot_unit_templates(self): ) def test_plot_unit_waveforms_density_map(self): - possible_backends = list(sw.UnitWaveformDensityMapWidget.possible_backends.keys()) + possible_backends = list(sw.UnitWaveformDensityMapWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: unit_ids = self.sorting.unit_ids[:2] @@ -173,7 +177,7 @@ def test_plot_unit_waveforms_density_map(self): ) def test_plot_unit_waveforms_density_map_sparsity_radius(self): - possible_backends = list(sw.UnitWaveformDensityMapWidget.possible_backends.keys()) + possible_backends = list(sw.UnitWaveformDensityMapWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: unit_ids = self.sorting.unit_ids[:2] @@ -187,7 +191,7 @@ def test_plot_unit_waveforms_density_map_sparsity_radius(self): ) def test_plot_unit_waveforms_density_map_sparsity_None_same_axis(self): - possible_backends = list(sw.UnitWaveformDensityMapWidget.possible_backends.keys()) + possible_backends = list(sw.UnitWaveformDensityMapWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: unit_ids = self.sorting.unit_ids[:2] @@ -201,7 +205,7 @@ def test_plot_unit_waveforms_density_map_sparsity_None_same_axis(self): ) def test_autocorrelograms(self): - possible_backends = list(sw.AutoCorrelogramsWidget.possible_backends.keys()) + possible_backends = list(sw.AutoCorrelogramsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: unit_ids = self.sorting.unit_ids[:4] @@ -215,7 +219,7 @@ def test_autocorrelograms(self): ) def test_crosscorrelogram(self): - possible_backends = list(sw.CrossCorrelogramsWidget.possible_backends.keys()) + possible_backends = list(sw.CrossCorrelogramsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: unit_ids = self.sorting.unit_ids[:4] @@ -229,7 +233,7 @@ def test_crosscorrelogram(self): ) def test_amplitudes(self): - possible_backends = list(sw.AmplitudesWidget.possible_backends.keys()) + possible_backends = list(sw.AmplitudesWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: sw.plot_amplitudes(self.we, backend=backend, **self.backend_kwargs[backend]) @@ -247,7 +251,7 @@ def test_amplitudes(self): ) def test_plot_all_amplitudes_distributions(self): - possible_backends = list(sw.AllAmplitudesDistributionsWidget.possible_backends.keys()) + possible_backends = list(sw.AllAmplitudesDistributionsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: unit_ids = self.we.unit_ids[:4] 
@@ -259,7 +263,7 @@ def test_plot_all_amplitudes_distributions(self): ) def test_unit_locations(self): - possible_backends = list(sw.UnitLocationsWidget.possible_backends.keys()) + possible_backends = list(sw.UnitLocationsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: sw.plot_unit_locations(self.we, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend]) @@ -268,7 +272,7 @@ def test_unit_locations(self): ) def test_spike_locations(self): - possible_backends = list(sw.SpikeLocationsWidget.possible_backends.keys()) + possible_backends = list(sw.SpikeLocationsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: sw.plot_spike_locations(self.we, with_channel_ids=True, backend=backend, **self.backend_kwargs[backend]) @@ -277,35 +281,35 @@ def test_spike_locations(self): ) def test_similarity(self): - possible_backends = list(sw.TemplateSimilarityWidget.possible_backends.keys()) + possible_backends = list(sw.TemplateSimilarityWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: sw.plot_template_similarity(self.we, backend=backend, **self.backend_kwargs[backend]) sw.plot_template_similarity(self.we_sparse, backend=backend, **self.backend_kwargs[backend]) def test_quality_metrics(self): - possible_backends = list(sw.QualityMetricsWidget.possible_backends.keys()) + possible_backends = list(sw.QualityMetricsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: sw.plot_quality_metrics(self.we, backend=backend, **self.backend_kwargs[backend]) sw.plot_quality_metrics(self.we_sparse, backend=backend, **self.backend_kwargs[backend]) def test_template_metrics(self): - possible_backends = list(sw.TemplateMetricsWidget.possible_backends.keys()) + possible_backends = list(sw.TemplateMetricsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: sw.plot_template_metrics(self.we, backend=backend, **self.backend_kwargs[backend]) sw.plot_template_metrics(self.we_sparse, backend=backend, **self.backend_kwargs[backend]) def test_plot_unit_depths(self): - possible_backends = list(sw.UnitDepthsWidget.possible_backends.keys()) + possible_backends = list(sw.UnitDepthsWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: sw.plot_unit_depths(self.we, backend=backend, **self.backend_kwargs[backend]) sw.plot_unit_depths(self.we_sparse, backend=backend, **self.backend_kwargs[backend]) def test_plot_unit_summary(self): - possible_backends = list(sw.UnitSummaryWidget.possible_backends.keys()) + possible_backends = list(sw.UnitSummaryWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: sw.plot_unit_summary( @@ -316,7 +320,7 @@ def test_plot_unit_summary(self): ) def test_sorting_summary(self): - possible_backends = list(sw.SortingSummaryWidget.possible_backends.keys()) + possible_backends = list(sw.SortingSummaryWidget.get_possible_backends()) for backend in possible_backends: if backend not in self.skip_backends: sw.plot_sorting_summary(self.we, backend=backend, **self.backend_kwargs[backend]) @@ -339,8 +343,9 @@ def test_sorting_summary(self): # mytest.test_plot_unit_depths() # mytest.test_plot_unit_templates() # mytest.test_plot_unit_summary() - mytest.test_quality_metrics() - mytest.test_template_metrics() + mytest.test_unit_locations() + # 
mytest.test_quality_metrics() + # mytest.test_template_metrics() # plt.ion() plt.show() diff --git a/src/spikeinterface/widgets/unit_locations.py b/src/spikeinterface/widgets/unit_locations.py index 4ea306bad6..036158cda7 100644 --- a/src/spikeinterface/widgets/unit_locations.py +++ b/src/spikeinterface/widgets/unit_locations.py @@ -79,10 +79,11 @@ def __init__( plot_legend=plot_legend, hide_axis=hide_axis, ) - + BaseWidget.__init__(self, data_plot, backend=backend, **backend_kwargs) def plot_matplotlib(self, data_plot, **backend_kwargs): + print(data_plot, backend_kwargs) import matplotlib.pyplot as plt from .matplotlib_utils import make_mpl_figure from probeinterface.plotting import plot_probe diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index 4dbd4b3c68..53f2e7eb62 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -1,29 +1,30 @@ -from .base import define_widget_function_from_class +# from .base import define_widget_function_from_class +from .base import backend_kwargs_desc # basics -from .timeseries import TimeseriesWidget +# from .timeseries import TimeseriesWidget # waveform -from .unit_waveforms import UnitWaveformsWidget -from .unit_templates import UnitTemplatesWidget -from .unit_waveforms_density_map import UnitWaveformDensityMapWidget +# from .unit_waveforms import UnitWaveformsWidget +# from .unit_templates import UnitTemplatesWidget +# from .unit_waveforms_density_map import UnitWaveformDensityMapWidget # isi/ccg/acg -from .autocorrelograms import AutoCorrelogramsWidget -from .crosscorrelograms import CrossCorrelogramsWidget +# from .autocorrelograms import AutoCorrelogramsWidget +# from .crosscorrelograms import CrossCorrelogramsWidget # peak activity # drift/motion # spikes-traces -from .spikes_on_traces import SpikesOnTracesWidget +# from .spikes_on_traces import SpikesOnTracesWidget # PC related # units on probe from .unit_locations import UnitLocationsWidget -from .spike_locations import SpikeLocationsWidget +# from .spike_locations import SpikeLocationsWidget # unit presence @@ -33,26 +34,26 @@ # correlogram comparison # amplitudes -from .amplitudes import AmplitudesWidget -from .all_amplitudes_distributions import AllAmplitudesDistributionsWidget +# from .amplitudes import AmplitudesWidget +# from .all_amplitudes_distributions import AllAmplitudesDistributionsWidget # metrics -from .quality_metrics import QualityMetricsWidget -from .template_metrics import TemplateMetricsWidget +# from .quality_metrics import QualityMetricsWidget +# from .template_metrics import TemplateMetricsWidget # motion/drift -from .motion import MotionWidget +# from .motion import MotionWidget # similarity -from .template_similarity import TemplateSimilarityWidget +# from .template_similarity import TemplateSimilarityWidget -from .unit_depths import UnitDepthsWidget +# from .unit_depths import UnitDepthsWidget # summary -from .unit_summary import UnitSummaryWidget -from .sorting_summary import SortingSummaryWidget +# from .unit_summary import UnitSummaryWidget +# from .sorting_summary import SortingSummaryWidget widget_list = [ @@ -89,13 +90,16 @@ **backend_kwargs: kwargs {backend_kwargs} """ - backend_str = f" {list(wcls.possible_backends.keys())}" + # backend_str = f" {list(wcls.possible_backends.keys())}" + backend_str = f" {wcls.get_possible_backends()}" backend_kwargs_str = "" - for backend, backend_plotter in wcls.possible_backends.items(): - backend_kwargs_desc = 
backend_plotter.backend_kwargs_desc - if len(backend_kwargs_desc) > 0: + # for backend, backend_plotter in wcls.possible_backends.items(): + for backend in wcls.get_possible_backends(): + # backend_kwargs_desc = backend_plotter.backend_kwargs_desc + kwargs_desc = backend_kwargs_desc[backend] + if len(kwargs_desc) > 0: backend_kwargs_str += f"\n {backend}:\n\n" - for bk, bk_dsc in backend_kwargs_desc.items(): + for bk, bk_dsc in kwargs_desc.items(): backend_kwargs_str += f" * {bk}: {bk_dsc}\n" wcls.__doc__ = wcls_doc.format(backends=backend_str, backend_kwargs=backend_kwargs_str) From 7e5ca37e8fb97cf94cd41971e020ff7da8c7cbae Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 18 Jul 2023 15:53:36 +0200 Subject: [PATCH 090/166] Integrate some comments from Ramon. --- src/spikeinterface/core/__init__.py | 2 +- src/spikeinterface/core/basesorting.py | 8 +++----- src/spikeinterface/core/numpyextractors.py | 9 ++++++--- src/spikeinterface/core/sortingfolder.py | 13 ++++++------- .../core/tests/test_numpy_extractors.py | 10 ++++++---- 5 files changed, 22 insertions(+), 20 deletions(-) diff --git a/src/spikeinterface/core/__init__.py b/src/spikeinterface/core/__init__.py index 0e93eb5877..d44890f844 100644 --- a/src/spikeinterface/core/__init__.py +++ b/src/spikeinterface/core/__init__.py @@ -11,7 +11,7 @@ from .numpyextractors import NumpyRecording, NumpySorting, SharedMemorySorting, NumpyEvent, NumpySnippets from .zarrrecordingextractor import ZarrRecordingExtractor, read_zarr, get_default_zarr_compressor from .binaryfolder import BinaryFolderRecording, read_binary_folder -from .sortingfolder import NumpyFolderSorting, NpzFolderSorting, read_numpy_sorting_folder_folder, read_npz_folder +from .sortingfolder import NumpyFolderSorting, NpzFolderSorting, read_numpy_sorting_folder, read_npz_folder from .npysnippetsextractor import NpySnippetsExtractor, read_npy_snippets from .npyfoldersnippets import NpyFolderSnippets, read_npy_snippets_folder diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index 0a4d9fd4f6..997b6995ae 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -223,8 +223,6 @@ def _save(self, format="numpy_folder", **save_kwargs): Since v0.98.0 'numpy_folder' is used by defult. From v0.96.0 to 0.97.0 'npz_folder' was the default. - - At the moment only 'npz' is supported. """ if format == "numpy_folder": from .sortingfolder import NumpyFolderSorting @@ -474,7 +472,7 @@ def to_spike_vector(self, concatenated=True, extremum_channel_inds=None, use_cac sample_indices = [] unit_indices = [] for u, unit_id in enumerate(self.unit_ids): - spike_times = st = self.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index) + spike_times = self.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index) sample_indices.append(spike_times) unit_indices.append(np.full(spike_times.size, u, dtype="int64")) @@ -527,7 +525,7 @@ def to_numpy_sorting(self, propagate_cache=True): def to_shared_memory_sorting(self): """ Turn any sorting in a SharedMemorySorting. - Usefull to have it in memory with a unique vector representation and sharable acros processes. + Usefull to have it in memory with a unique vector representation and sharable across processes. 
""" from .numpyextractors import SharedMemorySorting @@ -538,7 +536,7 @@ def to_multiprocessing(self, n_jobs): """ When necessary turn sorting object into: * NumpySorting when n_jobs=1 - * SharedMemorySorting whe, n_jobs>1 + * SharedMemorySorting when n_jobs>1 If the sorting is already NumpySorting, SharedMemorySorting or NumpyFolderSorting then this return the sortign itself, no transformation so. diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py index 12120734e1..97f22615df 100644 --- a/src/spikeinterface/core/numpyextractors.py +++ b/src/spikeinterface/core/numpyextractors.py @@ -104,8 +104,11 @@ class NumpySorting(BaseSorting): The internal representation is always done with a long "spike vector". - But we have convinient function to instantiate from other sorting object, from time+labels, - from dict of list or from neo. + But we have convenient class methods to instantiate from: + * other sorting object: `NumpySorting.from_sorting()` + * from time+labels: `NumpySorting.from_times_labels()` + * from dict of list: `NumpySorting.from_unit_dict()` + * from neo: `NumpySorting.from_neo_spiketrain_list()` Parameters ---------- @@ -247,7 +250,7 @@ def from_unit_dict(units_dict_list, sampling_frequency) -> "NumpySorting": sorting = NumpySorting(spikes, sampling_frequency, unit_ids) - # Trick : pupulate the cache with dict that already exists + # Trick : populate the cache with dict that already exists sorting._cached_spike_trains = {seg_ind: d for seg_ind, d in enumerate(units_dict_list)} return sorting diff --git a/src/spikeinterface/core/sortingfolder.py b/src/spikeinterface/core/sortingfolder.py index 32b983f7e7..49619bca06 100644 --- a/src/spikeinterface/core/sortingfolder.py +++ b/src/spikeinterface/core/sortingfolder.py @@ -12,12 +12,11 @@ class NumpyFolderSorting(BaseSorting): """ - NumpyFolderSorting is the new internal format used in spikeinterface (>=0.98.0) for caching - sorting obecjts. + NumpyFolderSorting is the new internal format used in spikeinterface (>=0.99.0) for caching sorting objects. It is a simple folder that contains: - * a file "spike.npy" (numpy formt) with all flatten spikes (using sorting.to_spike_vector()) - * a "numpysorting_info.json" containing sampling_frequenc, unit_ids and num_segments + * a file "spike.npy" (numpy format) with all flatten spikes (using sorting.to_spike_vector()) + * a "numpysorting_info.json" containing sampling_frequency, unit_ids and num_segments * a metadata folder for units properties. 
It is created with the function: `sorting.save(folder='/myfolder', format="numpy_folder")` @@ -73,7 +72,7 @@ def write_sorting(sorting, save_path): class NpzFolderSorting(NpzSortingExtractor): """ - NpzFolderSorting is the old internal format used in spikeinterface (<=0.97.0) + NpzFolderSorting is the old internal format used in spikeinterface (<=0.98.0) This a folder that contains: @@ -131,7 +130,7 @@ def write_sorting(sorting, save_path): cached.dump(save_path / "npz.json", relative_to=save_path) -read_numpy_sorting_folder_folder = define_function_from_class( - source_class=NumpyFolderSorting, name="read_numpy_sorting_folder_folder" +read_numpy_sorting_folder = define_function_from_class( + source_class=NumpyFolderSorting, name="read_numpy_sorting_folder" ) read_npz_folder = define_function_from_class(source_class=NpzFolderSorting, name="read_npz_folder") diff --git a/src/spikeinterface/core/tests/test_numpy_extractors.py b/src/spikeinterface/core/tests/test_numpy_extractors.py index 36a7585e7c..6c504f3765 100644 --- a/src/spikeinterface/core/tests/test_numpy_extractors.py +++ b/src/spikeinterface/core/tests/test_numpy_extractors.py @@ -62,8 +62,10 @@ def test_NumpySorting(): sorting = NumpySorting.from_sorting(other_sorting) # print(sorting) - # TODO test too_dict()/ - # TODO some test on caching + # construct back from kwargs keep the same array + sorting2 = load_extractor(sorting.to_dict()) + assert np.shares_memory(sorting2._cached_spike_vector, sorting._cached_spike_vector) + def test_SharedMemorySorting(): @@ -131,6 +133,6 @@ def test_NumpyEvent(): if __name__ == "__main__": # test_NumpyRecording() - # test_NumpySorting() - test_SharedMemorySorting() + test_NumpySorting() + # test_SharedMemorySorting() # test_NumpyEvent() From 3594a8f6c5ca5eafb5ef5836d7606ed7ad1354f1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 18 Jul 2023 13:53:58 +0000 Subject: [PATCH 091/166] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/tests/test_numpy_extractors.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/spikeinterface/core/tests/test_numpy_extractors.py b/src/spikeinterface/core/tests/test_numpy_extractors.py index 6c504f3765..4a5bffbc05 100644 --- a/src/spikeinterface/core/tests/test_numpy_extractors.py +++ b/src/spikeinterface/core/tests/test_numpy_extractors.py @@ -65,7 +65,6 @@ def test_NumpySorting(): # construct back from kwargs keep the same array sorting2 = load_extractor(sorting.to_dict()) assert np.shares_memory(sorting2._cached_spike_vector, sorting._cached_spike_vector) - def test_SharedMemorySorting(): @@ -134,5 +133,5 @@ def test_NumpyEvent(): if __name__ == "__main__": # test_NumpyRecording() test_NumpySorting() - # test_SharedMemorySorting() + # test_SharedMemorySorting() # test_NumpyEvent() From 3f236c703e830c931223bab7fdda5fa0de84cd59 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 19 Jul 2023 09:39:29 +0200 Subject: [PATCH 092/166] widgets utils files --- .../widgets/ipywidgets_utils.py | 105 ++++++++++++++++++ .../widgets/matplotlib_utils.py | 75 +++++++++++++ .../widgets/sortingview_utils.py | 95 ++++++++++++++++ 3 files changed, 275 insertions(+) create mode 100644 src/spikeinterface/widgets/ipywidgets_utils.py create mode 100644 src/spikeinterface/widgets/matplotlib_utils.py create mode 100644 src/spikeinterface/widgets/sortingview_utils.py diff --git 
a/src/spikeinterface/widgets/ipywidgets_utils.py b/src/spikeinterface/widgets/ipywidgets_utils.py new file mode 100644 index 0000000000..4490cc3365 --- /dev/null +++ b/src/spikeinterface/widgets/ipywidgets_utils.py @@ -0,0 +1,105 @@ +import ipywidgets.widgets as widgets +import numpy as np + + + +def check_ipywidget_backend(): + import matplotlib + mpl_backend = matplotlib.get_backend() + assert "ipympl" in mpl_backend, "To use the 'ipywidgets' backend, you have to set %matplotlib widget" + + + +def make_timeseries_controller(t_start, t_stop, layer_keys, num_segments, time_range, mode, all_layers, width_cm): + time_slider = widgets.FloatSlider( + orientation="horizontal", + description="time:", + value=time_range[0], + min=t_start, + max=t_stop, + continuous_update=False, + layout=widgets.Layout(width=f"{width_cm}cm"), + ) + layer_selector = widgets.Dropdown(description="layer", options=layer_keys) + segment_selector = widgets.Dropdown(description="segment", options=list(range(num_segments))) + window_sizer = widgets.BoundedFloatText(value=np.diff(time_range)[0], step=0.1, min=0.005, description="win (s)") + mode_selector = widgets.Dropdown(options=["line", "map"], description="mode", value=mode) + all_layers = widgets.Checkbox(description="plot all layers", value=all_layers) + + controller = { + "layer_key": layer_selector, + "segment_index": segment_selector, + "window": window_sizer, + "t_start": time_slider, + "mode": mode_selector, + "all_layers": all_layers, + } + widget = widgets.VBox( + [time_slider, widgets.HBox([all_layers, layer_selector, segment_selector, window_sizer, mode_selector])] + ) + + return widget, controller + + +def make_unit_controller(unit_ids, all_unit_ids, width_cm, height_cm): + unit_label = widgets.Label(value="units:") + + unit_selector = widgets.SelectMultiple( + options=all_unit_ids, + value=list(unit_ids), + disabled=False, + layout=widgets.Layout(width=f"{width_cm}cm", height=f"{height_cm}cm"), + ) + + controller = {"unit_ids": unit_selector} + widget = widgets.VBox([unit_label, unit_selector]) + + return widget, controller + + +def make_channel_controller(recording, width_cm, height_cm): + channel_label = widgets.Label("channel indices:", layout=widgets.Layout(justify_content="center")) + channel_selector = widgets.IntRangeSlider( + value=[0, recording.get_num_channels()], + min=0, + max=recording.get_num_channels(), + step=1, + disabled=False, + continuous_update=False, + orientation="vertical", + readout=True, + readout_format="d", + layout=widgets.Layout(width=f"{0.8 * width_cm}cm", height=f"{height_cm}cm"), + ) + + controller = {"channel_inds": channel_selector} + widget = widgets.VBox([channel_label, channel_selector]) + + return widget, controller + + +def make_scale_controller(width_cm, height_cm): + scale_label = widgets.Label("Scale", layout=widgets.Layout(justify_content="center")) + + plus_selector = widgets.Button( + description="", + disabled=False, + button_style="", # 'success', 'info', 'warning', 'danger' or '' + tooltip="Increase scale", + icon="arrow-up", + layout=widgets.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"), + ) + + minus_selector = widgets.Button( + description="", + disabled=False, + button_style="", # 'success', 'info', 'warning', 'danger' or '' + tooltip="Decrease scale", + icon="arrow-down", + layout=widgets.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"), + ) + + controller = {"plus": plus_selector, "minus": minus_selector} + widget = widgets.VBox([scale_label, plus_selector, 
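A usage sketch for the controller factories above (assumes a Jupyter context with the ipympl backend enabled; the unit ids are made up):

    # build the unit-selection panel and react to selection changes
    widget, controller = make_unit_controller(
        unit_ids=["u0", "u1"], all_unit_ids=["u0", "u1", "u2", "u3"], width_cm=10, height_cm=4
    )
    controller["unit_ids"].observe(lambda change: print(change["new"]), names="value")
    display(widget)  # display() is the IPython built-in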
minus_selector])
+
+    return widget, controller
diff --git a/src/spikeinterface/widgets/matplotlib_utils.py b/src/spikeinterface/widgets/matplotlib_utils.py
new file mode 100644
index 0000000000..6ccaaf5840
--- /dev/null
+++ b/src/spikeinterface/widgets/matplotlib_utils.py
@@ -0,0 +1,75 @@
+import matplotlib.pyplot as plt
+import numpy as np
+
+
+def make_mpl_figure(figure=None, ax=None, axes=None, ncols=None, num_axes=None, figsize=None, figtitle=None):
+    """
+    figure/ax/axes : only one of them can be not None
+    """
+    if figure is not None:
+        assert ax is None and axes is None, "figure/ax/axes : only one of them can be not None"
+        if num_axes is None:
+            ax = figure.add_subplot(111)
+            axes = np.array([[ax]])
+        else:
+            assert ncols is not None
+            axes = []
+            nrows = int(np.ceil(num_axes / ncols))
+            axes = np.full((nrows, ncols), fill_value=None, dtype=object)
+            for i in range(num_axes):
+                ax = figure.add_subplot(nrows, ncols, i + 1)
+                r = i // ncols
+                c = i % ncols
+                axes[r, c] = ax
+    elif ax is not None:
+        assert figure is None and axes is None, "figure/ax/axes : only one of them can be not None"
+        figure = ax.get_figure()
+        axes = np.array([[ax]])
+    elif axes is not None:
+        assert figure is None and ax is None, "figure/ax/axes : only one of them can be not None"
+        axes = np.asarray(axes)
+        figure = axes.flatten()[0].get_figure()
+    else:
+        # figure/ax/axes are all None
+        if num_axes is None:
+            # one fig with one ax
+            figure, ax = plt.subplots(figsize=figsize)
+            axes = np.array([[ax]])
+        else:
+            if num_axes == 0:
+                # one figure without plots (deferred subplot creation)
+                figure = plt.figure(figsize=figsize)
+                ax = None
+                axes = None
+            elif num_axes == 1:
+                figure = plt.figure(figsize=figsize)
+                ax = figure.add_subplot(111)
+                axes = np.array([[ax]])
+            else:
+                assert ncols is not None
+                if num_axes < ncols:
+                    ncols = num_axes
+                nrows = int(np.ceil(num_axes / ncols))
+                figure, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize)
+                ax = None
+                # remove extra axes
+                if ncols * nrows > num_axes:
+                    for i, extra_ax in enumerate(axes.flatten()):
+                        if i >= num_axes:
+                            extra_ax.remove()
+                            r = i // ncols
+                            c = i % ncols
+                            axes[r, c] = None
+
+    if figtitle is not None:
+        figure.suptitle(figtitle)
+
+    return figure, axes, ax
+
+    # self.figure = figure
+    # self.ax = ax
+    # axes is always a 2D array of ax
+    # self.axes = axes
+
+    # if figtitle is not None:
+    #     self.figure.suptitle(figtitle)
\ No newline at end of file
diff --git a/src/spikeinterface/widgets/sortingview_utils.py b/src/spikeinterface/widgets/sortingview_utils.py
new file mode 100644
index 0000000000..8a4a8f3169
--- /dev/null
+++ b/src/spikeinterface/widgets/sortingview_utils.py
@@ -0,0 +1,95 @@
+import numpy as np
+
+from ..core.core_tools import check_json
+
+
+
+
+sortingview_backend_kwargs_desc = {
+    "generate_url": "If True, the figurl URL is generated and printed. Default True",
+    "display": "If True and in jupyter notebook/lab, the widget is displayed in the cell. Default True.",
+    "figlabel": "The figurl figure label. Default None",
+    "height": "The height of the sortingview View in jupyter. 
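make_mpl_figure above funnels the mutually exclusive figure/ax/axes inputs into a uniform (figure, axes, ax) triple, with axes always a 2D object array. A usage sketch, assuming the module is importable:

    import matplotlib.pyplot as plt

    # a grid of 5 plots on 2 columns; the unused 6th slot is removed and set to None
    figure, axes, ax = make_mpl_figure(num_axes=5, ncols=2, figsize=(8, 6))

    # or wrap an existing axis
    _, existing_ax = plt.subplots()
    figure2, axes2, ax2 = make_mpl_figure(ax=existing_ax)
    assert axes2[0, 0] is existing_ax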
Default None", +} +sortingview_default_backend_kwargs = {"generate_url": True, "display": True, "figlabel": None, "height": None} + + + +def make_serializable(*args): + dict_to_serialize = {int(i): a for i, a in enumerate(args[1:])} + serializable_dict = check_json(dict_to_serialize) + returns = () + for i in range(len(args) - 1): + returns += (serializable_dict[str(i)],) + if len(returns) == 1: + returns = returns[0] + return returns + +def is_notebook() -> bool: + try: + shell = get_ipython().__class__.__name__ + if shell == "ZMQInteractiveShell": + return True # Jupyter notebook or qtconsole + elif shell == "TerminalInteractiveShell": + return False # Terminal running IPython + else: + return False # Other type (?) + except NameError: + return False + +def handle_display_and_url(widget, view, **backend_kwargs): + url = None + if is_notebook() and backend_kwargs["display"]: + display(view.jupyter(height=backend_kwargs["height"])) + if backend_kwargs["generate_url"]: + figlabel = backend_kwargs.get("figlabel") + if figlabel is None: + figlabel = widget.default_label + url = view.url(label=figlabel) + print(url) + + return url + + + + +def generate_unit_table_view(sorting, unit_properties=None, similarity_scores=None): + import sortingview.views as vv + + if unit_properties is None: + ut_columns = [] + ut_rows = [vv.UnitsTableRow(unit_id=u, values={}) for u in sorting.unit_ids] + else: + ut_columns = [] + ut_rows = [] + values = {} + valid_unit_properties = [] + for prop_name in unit_properties: + property_values = sorting.get_property(prop_name) + # make dtype available + val0 = np.array(property_values[0]) + if val0.dtype.kind in ("i", "u"): + dtype = "int" + elif val0.dtype.kind in ("U", "S"): + dtype = "str" + elif val0.dtype.kind == "f": + dtype = "float" + elif val0.dtype.kind == "b": + dtype = "bool" + else: + print(f"Unsupported dtype {val0.dtype} for property {prop_name}. 
Skipping") + continue + ut_columns.append(vv.UnitsTableColumn(key=prop_name, label=prop_name, dtype=dtype)) + valid_unit_properties.append(prop_name) + + for ui, unit in enumerate(sorting.unit_ids): + for prop_name in valid_unit_properties: + property_values = sorting.get_property(prop_name) + val0 = property_values[0] + if np.isnan(property_values[ui]): + continue + values[prop_name] = property_values[ui] + ut_rows.append(vv.UnitsTableRow(unit_id=unit, values=check_json(values))) + + v_units_table = vv.UnitsTable(rows=ut_rows, columns=ut_columns, similarity_scores=similarity_scores) + return v_units_table From a149f7d7a07f0c54a53018e83d6ad89e703e6a30 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 19 Jul 2023 09:42:33 +0200 Subject: [PATCH 093/166] Try hdbscan latest release --- pyproject.toml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d5cb7dacf6..e767904fef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -120,8 +120,7 @@ test = [ # tridesclous "numpy<1.24", "numba", - "Cython<3.0.0", - "hdbscan<=0.8.29", + "hdbscan", # for sortingview backend "sortingview", From c2cab6a42a384d0f91d5c3bbc9a79cb9c91e2da1 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Wed, 19 Jul 2023 09:58:25 +0200 Subject: [PATCH 094/166] fix dependencies --- docs_rtd.yml | 1 - environment_rtd.yml | 18 ------------------ pyproject.toml | 4 ++-- 3 files changed, 2 insertions(+), 21 deletions(-) delete mode 100644 environment_rtd.yml diff --git a/docs_rtd.yml b/docs_rtd.yml index 975aafb46b..c4e1fb378c 100644 --- a/docs_rtd.yml +++ b/docs_rtd.yml @@ -5,6 +5,5 @@ dependencies: - python=3.10 - pip - datalad - - hdbscan - pip: - -e .[docs] diff --git a/environment_rtd.yml b/environment_rtd.yml deleted file mode 100644 index 5e4b4eb92a..0000000000 --- a/environment_rtd.yml +++ /dev/null @@ -1,18 +0,0 @@ -channels: - - conda-forge - - defaults -dependencies: - - python=3.10 - - pip - - datalad - - numpy=1.23 - - pip: - - sphinx-gallery - - sphinx_rtd_theme - - numpydoc - - MEArec>=1.7.1 - - hdbscan - - numba - - git+https://github.com/NeuralEnsemble/python-neo.git - - git+https://github.com/SpikeInterface/probeinterface.git - - git+https://github.com/SpikeInterface/spikeinterface.git#egg=spikeinterface[full,widgets] diff --git a/pyproject.toml b/pyproject.toml index e767904fef..e84f44cd85 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -120,7 +120,7 @@ test = [ # tridesclous "numpy<1.24", "numba", - "hdbscan", + "hdbscan">=0.8.33, # for sortingview backend "sortingview", @@ -153,7 +153,7 @@ docs = [ "MEArec", # Use as an example "datalad==0.16.2", # Download mearec data, not sure if needed as is installed with conda as well because of git-annex "pandas", # Don't know where this is needed - "hdbscan", # For sorters, probably spikingcircus + "hdbscan">=0.8.33, # For sorters, probably spikingcircus "numba", # For sorters, probably spikingcircus # for release we need pypi, so this needs to be commented "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", # We always build from the latest version From af931b798250649943f008e20378a8a3b79c6303 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Wed, 19 Jul 2023 10:01:36 +0200 Subject: [PATCH 095/166] some useful comment that should trigger full test --- pyproject.toml | 4 ++-- src/spikeinterface/core/baserecording.py | 2 ++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e84f44cd85..0c56e1125b 100644 --- a/pyproject.toml 
+++ b/pyproject.toml @@ -120,7 +120,7 @@ test = [ # tridesclous "numpy<1.24", "numba", - "hdbscan">=0.8.33, + "hdbscan>=0.8.33", # Previous version had a broken wheel # for sortingview backend "sortingview", @@ -153,7 +153,7 @@ docs = [ "MEArec", # Use as an example "datalad==0.16.2", # Download mearec data, not sure if needed as is installed with conda as well because of git-annex "pandas", # Don't know where this is needed - "hdbscan">=0.8.33, # For sorters, probably spikingcircus + "hdbscan>=0.8.33", # For sorters, probably spikingcircus "numba", # For sorters, probably spikingcircus # for release we need pypi, so this needs to be commented "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", # We always build from the latest version diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 8c24e4e624..e7166def75 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -445,6 +445,8 @@ def _save(self, format="binary", **save_kwargs): from .binaryrecordingextractor import BinaryRecordingExtractor + # This is created so it can be saved as json because the `BinaryFolderRecording` requires it loading + # See the __init__ of `BinaryFolderRecording` binary_rec = BinaryRecordingExtractor( file_paths=file_paths, sampling_frequency=self.get_sampling_frequency(), From cc66afa8a40a3ec782fb4819b071e4f9c02f5025 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Wed, 19 Jul 2023 10:40:40 +0200 Subject: [PATCH 096/166] Fix Mearec handling of new arguments before neo release 0.13 (#1848) * fix neo version * drop hard to mantain comments --- .../extractors/neoextractors/mearec.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/extractors/neoextractors/mearec.py b/src/spikeinterface/extractors/neoextractors/mearec.py index 0ec7326f3f..7dda9175f5 100644 --- a/src/spikeinterface/extractors/neoextractors/mearec.py +++ b/src/spikeinterface/extractors/neoextractors/mearec.py @@ -8,8 +8,8 @@ from .neobaseextractor import NeoBaseRecordingExtractor, NeoBaseSortingExtractor -def drop_neo_arguments_in_version_0_11_0(neo_kwargs): - # Temporary function until neo version 0.12.0 is released +def drop_invalid_neo_arguments_before_version_0_13_0(neo_kwargs): + # Temporary function until neo version 0.13.0 is released from packaging.version import parse as parse_version from importlib.metadata import version @@ -17,7 +17,7 @@ def drop_neo_arguments_in_version_0_11_0(neo_kwargs): minor_version = parse_version(neo_version).minor # The possibility of loading only spike_trains or only analog_signals is not present in neo <= 0.11.0 - if minor_version < 12: + if minor_version < 13: neo_kwargs.pop("load_spiketrains") neo_kwargs.pop("load_analogsignal") @@ -68,8 +68,7 @@ def map_to_neo_kwargs( "load_spiketrains": False, "load_analogsignal": True, } - # The possibility of loading only spike_trains or only analog_signals will be added in neo version 0.12.0 - neo_kwargs = drop_neo_arguments_in_version_0_11_0(neo_kwargs=neo_kwargs) + neo_kwargs = drop_invalid_neo_arguments_before_version_0_13_0(neo_kwargs=neo_kwargs) return neo_kwargs @@ -94,8 +93,7 @@ def map_to_neo_kwargs(cls, file_path): "load_spiketrains": True, "load_analogsignal": False, } - # The possibility of loading only spike_trains or only analog_signals will be added in neo version 0.12.0 - neo_kwargs = drop_neo_arguments_in_version_0_11_0(neo_kwargs=neo_kwargs) + neo_kwargs = 
drop_invalid_neo_arguments_before_version_0_13_0(neo_kwargs=neo_kwargs) return neo_kwargs From a3b7cebd97f8cca6516e698002638af0d2400e73 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Wed, 19 Jul 2023 10:51:50 +0200 Subject: [PATCH 097/166] changes in tridesclous no longer require upper bound numpy (#1850) Co-authored-by: Alessio Buccino --- pyproject.toml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0c56e1125b..bcdd76a3df 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -118,7 +118,6 @@ test = [ "huggingface_hub", # tridesclous - "numpy<1.24", "numba", "hdbscan>=0.8.33", # Previous version had a broken wheel @@ -130,7 +129,7 @@ test = [ "datalad==0.16.2", ## install tridesclous for testing ## - "tridesclous>=1.6.6.1", + "tridesclous>=1.6.7", ## sliding_nn "pymde", From 8be32db05248468439a4df5822dcac430b64a130 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20WYNGAARD?= Date: Wed, 19 Jul 2023 11:13:46 +0200 Subject: [PATCH 098/166] Remove warning (#1843) * Remove warning BinaryRecordingExtractor with `num_chan` --------- Co-authored-by: Alessio Buccino Co-authored-by: Heberto Mayorquin --- .../sortingcomponents/clustering/clustering_tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 53833b01a2..6edf5af16b 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -579,7 +579,7 @@ def remove_duplicates_via_matching( f.write(blanck) f.close() - recording = BinaryRecordingExtractor(tmp_filename, num_chan=num_chans, sampling_frequency=fs, dtype="float32") + recording = BinaryRecordingExtractor(tmp_filename, num_channels=num_chans, sampling_frequency=fs, dtype="float32") recording.annotate(is_filtered=True) margin = 2 * max(waveform_extractor.nbefore, waveform_extractor.nafter) From a91409a6f09fb74a561a8317f97804f53300bad0 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 19 Jul 2023 11:23:51 +0200 Subject: [PATCH 099/166] Prepare 0.98.2 --- doc/releases/0.98.2.rst | 13 +++++++++++++ doc/whatisnew.rst | 7 +++++++ pyproject.toml | 10 +++++----- 3 files changed, 25 insertions(+), 5 deletions(-) create mode 100644 doc/releases/0.98.2.rst diff --git a/doc/releases/0.98.2.rst b/doc/releases/0.98.2.rst new file mode 100644 index 0000000000..d60a3e53a3 --- /dev/null +++ b/doc/releases/0.98.2.rst @@ -0,0 +1,13 @@ +.. _release0.98.2: + +SpikeInterface 0.98.2 release notes +----------------------------------- + +19th July 2023 + +Minor release with some bug fixes. + +* Remove warning (#1843) +* Fix Mearec handling of new arguments before neo release 0.13 (#1848) +* Fix full tests by updating hdbscan version (#1849) +* Relax numpy upper bound and update tridesclous dependency (#1850) diff --git a/doc/whatisnew.rst b/doc/whatisnew.rst index 21ad89af62..8b984e2510 100644 --- a/doc/whatisnew.rst +++ b/doc/whatisnew.rst @@ -8,6 +8,7 @@ Release notes .. 
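The MEArec fix a few patches above gates reader kwargs on the installed neo version. The same guard in isolation (package name and threshold mirror the diff; the filename is made up and this is a sketch, not the exact helper):

    from importlib.metadata import version
    from packaging.version import parse as parse_version

    neo_kwargs = {"filename": "recordings.h5", "load_spiketrains": True, "load_analogsignal": False}
    if parse_version(version("neo")).minor < 13:
        # selective loading of spiketrains/analogsignals only exists from neo 0.13.0 on
        neo_kwargs.pop("load_spiketrains")
        neo_kwargs.pop("load_analogsignal")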
toctree:: :maxdepth: 1 + releases/0.98.2.rst releases/0.98.1.rst releases/0.98.0.rst releases/0.97.1.rst @@ -30,6 +31,12 @@ Release notes releases/0.9.1.rst +Version 0.98.2 +============== + +* Minor release with some bug fixes + + Version 0.98.1 ============== diff --git a/pyproject.toml b/pyproject.toml index 0c56e1125b..38ef09ff0c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "spikeinterface" -version = "0.99.0.dev0" +version = "0.98.2" authors = [ { name="Alessio Buccino", email="alessiop.buccino@gmail.com" }, { name="Samuel Garcia", email="sam.garcia.die@gmail.com" }, @@ -139,8 +139,8 @@ test = [ # for github test : probeinterface and neo from master # for release we need pypi, so this need to be commented - "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", - "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", + # "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", + # "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", ] docs = [ @@ -156,8 +156,8 @@ docs = [ "hdbscan>=0.8.33", # For sorters, probably spikingcircus "numba", # For sorters, probably spikingcircus # for release we need pypi, so this needs to be commented - "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", # We always build from the latest version - "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", # We always build from the latest version + # "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", # We always build from the latest version + # "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", # We always build from the latest version ] From 84170e37b929ee835bb084e93e4b53d5b168178b Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 19 Jul 2023 12:37:42 +0200 Subject: [PATCH 100/166] wip refactor widgets --- src/spikeinterface/widgets/base.py | 3 ++- src/spikeinterface/widgets/sortingview_utils.py | 7 ++++--- src/spikeinterface/widgets/tests/test_widgets.py | 1 + src/spikeinterface/widgets/unit_locations.py | 9 +++++---- 4 files changed, 12 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/widgets/base.py b/src/spikeinterface/widgets/base.py index 17903b495b..f95004efb9 100644 --- a/src/spikeinterface/widgets/base.py +++ b/src/spikeinterface/widgets/base.py @@ -73,6 +73,7 @@ def __init__(self, data_plot=None, backend=None, immediate_plot=True, **backend_ self.backend_kwargs = backend_kwargs_ if immediate_plot: + print('immediate_plot', self.backend, self.backend_kwargs) self.do_plot(self.backend, **self.backend_kwargs) @classmethod @@ -101,7 +102,7 @@ def do_plot(self, backend, **backend_kwargs): # backend = self.check_backend(backend) func = getattr(self, f'plot_{backend}') - func(data_plot=self.data_plot, **self.backend_kwargs) + func(self.data_plot, **self.backend_kwargs) # @classmethod # def register_backend(cls, backend_plotter): diff --git a/src/spikeinterface/widgets/sortingview_utils.py b/src/spikeinterface/widgets/sortingview_utils.py index 8a4a8f3169..90dfcb77a3 100644 --- a/src/spikeinterface/widgets/sortingview_utils.py +++ b/src/spikeinterface/widgets/sortingview_utils.py @@ -16,10 +16,10 @@ def make_serializable(*args): - dict_to_serialize = {int(i): a for i, a in enumerate(args[1:])} + dict_to_serialize = {int(i): a for i, a in enumerate(args)} serializable_dict = check_json(dict_to_serialize) returns = () - for i in range(len(args) - 1): + for i in range(len(args)): returns += (serializable_dict[str(i)],) if 
len(returns) == 1: returns = returns[0] @@ -44,7 +44,8 @@ def handle_display_and_url(widget, view, **backend_kwargs): if backend_kwargs["generate_url"]: figlabel = backend_kwargs.get("figlabel") if figlabel is None: - figlabel = widget.default_label + # figlabel = widget.default_label + figlabel = "" url = view.url(label=figlabel) print(url) diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index cb4341f044..1dff04d334 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -36,6 +36,7 @@ else: cache_folder = Path("cache_folder") / "widgets" +print(cache_folder) ON_GITHUB = bool(os.getenv("GITHUB_ACTIONS")) KACHERY_CLOUD_SET = bool(os.getenv("KACHERY_CLOUD_CLIENT_ID")) and bool(os.getenv("KACHERY_CLOUD_PRIVATE_KEY")) diff --git a/src/spikeinterface/widgets/unit_locations.py b/src/spikeinterface/widgets/unit_locations.py index 036158cda7..e87f553072 100644 --- a/src/spikeinterface/widgets/unit_locations.py +++ b/src/spikeinterface/widgets/unit_locations.py @@ -83,7 +83,6 @@ def __init__( BaseWidget.__init__(self, data_plot, backend=backend, **backend_kwargs) def plot_matplotlib(self, data_plot, **backend_kwargs): - print(data_plot, backend_kwargs) import matplotlib.pyplot as plt from .matplotlib_utils import make_mpl_figure from probeinterface.plotting import plot_probe @@ -99,7 +98,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): # self.make_mpl_figure(**backend_kwargs) - self.figure, self.axes, self.ax = make_mpl_figure(backend_kwargs) + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) unit_locations = dp.unit_locations @@ -180,6 +179,8 @@ def plot_sortingview(self, data_plot, **backend_kwargs): dp = to_attr(data_plot) # ensure serializable for sortingview + print(dp.unit_ids, dp.channel_ids) + print(make_serializable(dp.unit_ids, dp.channel_ids)) unit_ids, channel_ids = make_serializable(dp.unit_ids, dp.channel_ids) locations = {str(ch): dp.channel_locations[i_ch].astype("float32") for i_ch, ch in enumerate(channel_ids)} @@ -256,10 +257,10 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): ) # a first update - self.updater(None) + self.update_widget(None) if backend_kwargs["display"]: - self.check_backend() + # self.check_backend() display(self.widget) def update_widget(self, change): From e617165a551578d63ad020f10bfcdb6a199098c6 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 19 Jul 2023 16:10:57 +0200 Subject: [PATCH 101/166] Drop figurl-jupyter dependency --- pyproject.toml | 1 - src/spikeinterface/widgets/sortingview/base_sortingview.py | 5 +++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index bcdd76a3df..44e5b8c288 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -99,7 +99,6 @@ widgets = [ "ipympl", "ipywidgets", "sortingview>=0.11.15", - "figurl-jupyter" ] test_core = [ diff --git a/src/spikeinterface/widgets/sortingview/base_sortingview.py b/src/spikeinterface/widgets/sortingview/base_sortingview.py index c42da0fba3..8c015b87d1 100644 --- a/src/spikeinterface/widgets/sortingview/base_sortingview.py +++ b/src/spikeinterface/widgets/sortingview/base_sortingview.py @@ -43,8 +43,9 @@ def is_notebook() -> bool: def handle_display_and_url(self, view, **backend_kwargs): self.set_view(view) - if self.is_notebook() and backend_kwargs["display"]: - display(self.view.jupyter(height=backend_kwargs["height"])) + # figurl_jupyter is broken. Comment it out for now. 
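
As context for PATCH 100 above: `BaseWidget.do_plot` now looks up a `plot_<backend>` method by name (`getattr(self, f'plot_{backend}')`) instead of dispatching through registered plotter classes. A rough stand-alone sketch of that dispatch pattern — toy class and data, not the real SpikeInterface API:

    class ToyWidget:
        def do_plot(self, backend, **backend_kwargs):
            # resolve the backend-specific method by naming convention
            func = getattr(self, f"plot_{backend}")
            func(data_plot={"unit_ids": [0, 1]}, **backend_kwargs)

        def plot_matplotlib(self, data_plot, **backend_kwargs):
            print("matplotlib backend:", data_plot, backend_kwargs)

    ToyWidget().do_plot("matplotlib", figsize=(4, 3))
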
+ # if self.is_notebook() and backend_kwargs["display"]: + # display(self.view.jupyter(height=backend_kwargs["height"])) if backend_kwargs["generate_url"]: figlabel = backend_kwargs.get("figlabel") if figlabel is None: From 5f9e0c9f1e558e1f27aab39aae3c5a955bb144a0 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 19 Jul 2023 17:03:17 +0200 Subject: [PATCH 102/166] widget refactor : AllAmplitudesDistributionsWidget and AmplitudesWidget --- .../widgets/all_amplitudes_distributions.py | 41 +++- src/spikeinterface/widgets/amplitudes.py | 183 +++++++++++++++++- src/spikeinterface/widgets/base.py | 7 +- .../widgets/tests/test_widgets.py | 3 +- src/spikeinterface/widgets/unit_locations.py | 78 ++++---- src/spikeinterface/widgets/widget_list.py | 11 +- 6 files changed, 273 insertions(+), 50 deletions(-) diff --git a/src/spikeinterface/widgets/all_amplitudes_distributions.py b/src/spikeinterface/widgets/all_amplitudes_distributions.py index d1a0acfe1e..18585a4f96 100644 --- a/src/spikeinterface/widgets/all_amplitudes_distributions.py +++ b/src/spikeinterface/widgets/all_amplitudes_distributions.py @@ -1,7 +1,7 @@ import numpy as np from warnings import warn -from .base import BaseWidget +from .base import BaseWidget, to_attr from .utils import get_some_colors from ..core.waveform_extractor import WaveformExtractor @@ -47,3 +47,42 @@ def __init__( ) BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .matplotlib_utils import make_mpl_figure + + from matplotlib.patches import Ellipse + from matplotlib.lines import Line2D + + + dp = to_attr(data_plot) + # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) + # self.make_mpl_figure(**backend_kwargs) + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + ax = self.ax + + unit_amps = [] + for i, unit_id in enumerate(dp.unit_ids): + amps = [] + for segment_index in range(dp.num_segments): + amps.append(dp.amplitudes[segment_index][unit_id]) + amps = np.concatenate(amps) + unit_amps.append(amps) + parts = ax.violinplot(unit_amps, showmeans=False, showmedians=False, showextrema=False) + + for i, pc in enumerate(parts["bodies"]): + color = dp.unit_colors[dp.unit_ids[i]] + pc.set_facecolor(color) + pc.set_edgecolor("black") + pc.set_alpha(1) + + ax.set_xticks(np.arange(len(dp.unit_ids)) + 1) + ax.set_xticklabels([str(unit_id) for unit_id in dp.unit_ids]) + + ylims = ax.get_ylim() + if np.max(ylims) < 0: + ax.set_ylim(min(ylims), 0) + if np.min(ylims) > 0: + ax.set_ylim(0, max(ylims)) \ No newline at end of file diff --git a/src/spikeinterface/widgets/amplitudes.py b/src/spikeinterface/widgets/amplitudes.py index 833bdf2b06..7c76d26204 100644 --- a/src/spikeinterface/widgets/amplitudes.py +++ b/src/spikeinterface/widgets/amplitudes.py @@ -1,7 +1,7 @@ import numpy as np from warnings import warn -from .base import BaseWidget +from .base import BaseWidget, to_attr from .utils import get_some_colors from ..core.waveform_extractor import WaveformExtractor @@ -112,3 +112,184 @@ def __init__( ) BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .matplotlib_utils import make_mpl_figure + from probeinterface.plotting import plot_probe + + from matplotlib.patches import Ellipse + from matplotlib.lines import Line2D + + + + + dp = to_attr(data_plot) + # backend_kwargs = 
self.update_backend_kwargs(**backend_kwargs) + + if backend_kwargs["axes"] is not None: + axes = backend_kwargs["axes"] + if dp.plot_histograms: + assert np.asarray(axes).size == 2 + else: + assert np.asarray(axes).size == 1 + elif backend_kwargs["ax"] is not None: + assert not dp.plot_histograms + else: + if dp.plot_histograms: + backend_kwargs["num_axes"] = 2 + backend_kwargs["ncols"] = 2 + else: + backend_kwargs["num_axes"] = None + + # self.make_mpl_figure(**backend_kwargs) + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + scatter_ax = self.axes.flatten()[0] + + for unit_id in dp.unit_ids: + spiketrains = dp.spiketrains[unit_id] + amps = dp.amplitudes[unit_id] + scatter_ax.scatter(spiketrains, amps, color=dp.unit_colors[unit_id], s=3, alpha=1, label=unit_id) + + if dp.plot_histograms: + if dp.bins is None: + bins = int(len(spiketrains) / 30) + else: + bins = dp.bins + ax_hist = self.axes.flatten()[1] + ax_hist.hist(amps, bins=bins, orientation="horizontal", color=dp.unit_colors[unit_id], alpha=0.8) + + if dp.plot_histograms: + ax_hist = self.axes.flatten()[1] + ax_hist.set_ylim(scatter_ax.get_ylim()) + ax_hist.axis("off") + self.figure.tight_layout() + + if dp.plot_legend: + # if self.legend is not None: + if hasattr(self, 'legend') and self.legend is not None: + self.legend.remove() + self.legend = self.figure.legend( + loc="upper center", bbox_to_anchor=(0.5, 1.0), ncol=5, fancybox=True, shadow=True + ) + + scatter_ax.set_xlim(0, dp.total_duration) + scatter_ax.set_xlabel("Times [s]") + scatter_ax.set_ylabel(f"Amplitude") + scatter_ax.spines["top"].set_visible(False) + scatter_ax.spines["right"].set_visible(False) + self.figure.subplots_adjust(bottom=0.1, top=0.9, left=0.1) + + def plot_ipywidgets(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + import ipywidgets.widgets as widgets + from IPython.display import display + from .ipywidgets_utils import check_ipywidget_backend, make_unit_controller + + check_ipywidget_backend() + + self.next_data_plot = data_plot.copy() + + cm = 1 / 2.54 + we = data_plot["waveform_extractor"] + + # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) + width_cm = backend_kwargs["width_cm"] + height_cm = backend_kwargs["height_cm"] + + ratios = [0.15, 0.85] + + with plt.ioff(): + output = widgets.Output() + with output: + # fig = plt.figure(figsize=((ratios[1] * width_cm) * cm, height_cm * cm)) + self.figure = plt.figure(figsize=((ratios[1] * width_cm) * cm, height_cm * cm)) + plt.show() + + data_plot["unit_ids"] = data_plot["unit_ids"][:1] + unit_widget, unit_controller = make_unit_controller( + data_plot["unit_ids"], we.unit_ids, ratios[0] * width_cm, height_cm + ) + + plot_histograms = widgets.Checkbox( + value=data_plot["plot_histograms"], + description="plot histograms", + disabled=False, + ) + + footer = plot_histograms + + self.controller = {"plot_histograms": plot_histograms} + self.controller.update(unit_controller) + + # mpl_plotter = MplAmplitudesPlotter() + + # self.updater = PlotUpdater(data_plot, mpl_plotter, fig, self.controller) + for w in self.controller.values(): + # w.observe(self.updater) + w.observe(self._update_ipywidget) + + self.widget = widgets.AppLayout( + # center=fig.canvas, left_sidebar=unit_widget, pane_widths=ratios + [0], footer=footer + center=self.figure.canvas, left_sidebar=unit_widget, pane_widths=ratios + [0], footer=footer + ) + + # a first update + # self.updater(None) + self._update_ipywidget(None) + + if backend_kwargs["display"]: + # self.check_backend() + 
display(self.widget) + + def _update_ipywidget(self, change): + # self.fig.clear() + self.figure.clear() + + unit_ids = self.controller["unit_ids"].value + plot_histograms = self.controller["plot_histograms"].value + + # matplotlib next_data_plot dict update at each call + data_plot = self.next_data_plot + data_plot["unit_ids"] = unit_ids + data_plot["plot_histograms"] = plot_histograms + + backend_kwargs = {} + # backend_kwargs["figure"] = self.fig + backend_kwargs["figure"] = self.figure + backend_kwargs["axes"] = None + backend_kwargs["ax"] = None + + # self.mpl_plotter.do_plot(data_plot, **backend_kwargs) + self.plot_matplotlib(data_plot, **backend_kwargs) + + self.figure.canvas.draw() + self.figure.canvas.flush_events() + + def plot_sortingview(self, data_plot, **backend_kwargs): + import sortingview.views as vv + from .sortingview_utils import generate_unit_table_view, make_serializable, handle_display_and_url + + # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) + dp = to_attr(data_plot) + + # unit_ids = self.make_serializable(dp.unit_ids) + unit_ids = make_serializable(dp.unit_ids) + + sa_items = [ + vv.SpikeAmplitudesItem( + unit_id=u, + spike_times_sec=dp.spiketrains[u].astype("float32"), + spike_amplitudes=dp.amplitudes[u].astype("float32"), + ) + for u in unit_ids + ] + + # v_spike_amplitudes = vv.SpikeAmplitudes( + self.view = vv.SpikeAmplitudes( + start_time_sec=0, end_time_sec=dp.total_duration, plots=sa_items, hide_unit_selector=dp.hide_unit_selector + ) + + # self.handle_display_and_url(v_spike_amplitudes, **backend_kwargs) + self.url = handle_display_and_url(self, self.view, **self.backend_kwargs) diff --git a/src/spikeinterface/widgets/base.py b/src/spikeinterface/widgets/base.py index f95004efb9..7b0ba0454e 100644 --- a/src/spikeinterface/widgets/base.py +++ b/src/spikeinterface/widgets/base.py @@ -58,16 +58,17 @@ class BaseWidget: def __init__(self, data_plot=None, backend=None, immediate_plot=True, **backend_kwargs, ): # every widgets must prepare a dict "plot_data" in the init self.data_plot = data_plot - self.backend = self.check_backend(backend) + backend = self.check_backend(backend) + self.backend = backend # check backend kwargs for k in backend_kwargs: if k not in default_backend_kwargs[backend]: raise Exception( f"{k} is not a valid plot argument or backend keyword argument. 
" - f"Possible backend keyword arguments for {backend} are: {list(plotter_kwargs.keys())}" + f"Possible backend keyword arguments for {backend} are: {list(default_backend_kwargs[backend].keys())}" ) - backend_kwargs_ = default_backend_kwargs[backend].copy() + backend_kwargs_ = default_backend_kwargs[self.backend].copy() backend_kwargs_.update(backend_kwargs) self.backend_kwargs = backend_kwargs_ diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index 1dff04d334..4ddec4134b 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -344,9 +344,10 @@ def test_sorting_summary(self): # mytest.test_plot_unit_depths() # mytest.test_plot_unit_templates() # mytest.test_plot_unit_summary() - mytest.test_unit_locations() + # mytest.test_unit_locations() # mytest.test_quality_metrics() # mytest.test_template_metrics() + mytest.test_amplitudes() # plt.ion() plt.show() diff --git a/src/spikeinterface/widgets/unit_locations.py b/src/spikeinterface/widgets/unit_locations.py index e87f553072..725a4c3023 100644 --- a/src/spikeinterface/widgets/unit_locations.py +++ b/src/spikeinterface/widgets/unit_locations.py @@ -171,51 +171,17 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): if dp.hide_axis: self.ax.axis("off") - def plot_sortingview(self, data_plot, **backend_kwargs): - import sortingview.views as vv - from .sortingview_utils import generate_unit_table_view, make_serializable, handle_display_and_url - - # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - dp = to_attr(data_plot) - - # ensure serializable for sortingview - print(dp.unit_ids, dp.channel_ids) - print(make_serializable(dp.unit_ids, dp.channel_ids)) - unit_ids, channel_ids = make_serializable(dp.unit_ids, dp.channel_ids) - - locations = {str(ch): dp.channel_locations[i_ch].astype("float32") for i_ch, ch in enumerate(channel_ids)} - - unit_items = [] - for unit_id in unit_ids: - unit_items.append( - vv.UnitLocationsItem( - unit_id=unit_id, x=float(dp.unit_locations[unit_id][0]), y=float(dp.unit_locations[unit_id][1]) - ) - ) - v_unit_locations = vv.UnitLocations(units=unit_items, channel_locations=locations, disable_auto_rotate=True) - - if not dp.hide_unit_selector: - v_units_table = generate_unit_table_view(dp.sorting) - - self.view = vv.Box( - direction="horizontal", - items=[vv.LayoutItem(v_units_table, max_size=150), vv.LayoutItem(v_unit_locations)], - ) - else: - self.view = v_unit_locations - - # self.handle_display_and_url(view, **backend_kwargs) - self.url = handle_display_and_url(self, self.view, **self.backend_kwargs) def plot_ipywidgets(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt import ipywidgets.widgets as widgets from IPython.display import display - from .ipywidgets_utils import check_ipywidget_backend, make_unit_controller + check_ipywidget_backend() + self.next_data_plot = data_plot.copy() cm = 1 / 2.54 @@ -248,7 +214,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): # w.observe(self.updater) for w in self.controller.values(): - w.observe(self.update_widget) + w.observe(self._update_ipywidget) self.widget = widgets.AppLayout( center=fig.canvas, @@ -257,13 +223,12 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): ) # a first update - self.update_widget(None) + self._update_ipywidget(None) if backend_kwargs["display"]: - # self.check_backend() display(self.widget) - def update_widget(self, change): + def _update_ipywidget(self, change): 
self.ax.clear() unit_ids = self.controller["unit_ids"].value @@ -284,5 +249,38 @@ def update_widget(self, change): fig.canvas.draw() fig.canvas.flush_events() + def plot_sortingview(self, data_plot, **backend_kwargs): + import sortingview.views as vv + from .sortingview_utils import generate_unit_table_view, make_serializable, handle_display_and_url + + # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) + dp = to_attr(data_plot) + + # ensure serializable for sortingview + unit_ids, channel_ids = make_serializable(dp.unit_ids, dp.channel_ids) + + locations = {str(ch): dp.channel_locations[i_ch].astype("float32") for i_ch, ch in enumerate(channel_ids)} + + unit_items = [] + for unit_id in unit_ids: + unit_items.append( + vv.UnitLocationsItem( + unit_id=unit_id, x=float(dp.unit_locations[unit_id][0]), y=float(dp.unit_locations[unit_id][1]) + ) + ) + + v_unit_locations = vv.UnitLocations(units=unit_items, channel_locations=locations, disable_auto_rotate=True) + if not dp.hide_unit_selector: + v_units_table = generate_unit_table_view(dp.sorting) + + self.view = vv.Box( + direction="horizontal", + items=[vv.LayoutItem(v_units_table, max_size=150), vv.LayoutItem(v_unit_locations)], + ) + else: + self.view = v_unit_locations + + # self.handle_display_and_url(view, **backend_kwargs) + self.url = handle_display_and_url(self, self.view, **self.backend_kwargs) diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index 53f2e7eb62..52ee03ebca 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -34,8 +34,8 @@ # correlogram comparison # amplitudes -# from .amplitudes import AmplitudesWidget -# from .all_amplitudes_distributions import AllAmplitudesDistributionsWidget +from .amplitudes import AmplitudesWidget +from .all_amplitudes_distributions import AllAmplitudesDistributionsWidget # metrics # from .quality_metrics import QualityMetricsWidget @@ -57,8 +57,8 @@ widget_list = [ - # AmplitudesWidget, - # AllAmplitudesDistributionsWidget, + AmplitudesWidget, + AllAmplitudesDistributionsWidget, # AutoCorrelogramsWidget, # CrossCorrelogramsWidget, # QualityMetricsWidget, @@ -129,4 +129,7 @@ # plot_sorting_summary = define_widget_function_from_class(SortingSummaryWidget, "plot_sorting_summary") +plot_amplitudes = AmplitudesWidget +plot_all_amplitudes_distributions = AllAmplitudesDistributionsWidget plot_unit_locations = UnitLocationsWidget + From 4170552e8947920743c5445cb7f093968148f547 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 19 Jul 2023 17:21:25 +0200 Subject: [PATCH 103/166] refactor widgets : AutoCorrelogramsWidget + CrossCorrelogramsWidget --- .../widgets/autocorrelograms.py | 59 +++++++++++++++- .../widgets/crosscorrelograms.py | 70 ++++++++++++++++++- src/spikeinterface/widgets/widget_list.py | 8 ++- 3 files changed, 131 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/widgets/autocorrelograms.py b/src/spikeinterface/widgets/autocorrelograms.py index 701817e168..f07246efa6 100644 --- a/src/spikeinterface/widgets/autocorrelograms.py +++ b/src/spikeinterface/widgets/autocorrelograms.py @@ -1,11 +1,68 @@ +from .base import BaseWidget, to_attr + from .crosscorrelograms import CrossCorrelogramsWidget class AutoCorrelogramsWidget(CrossCorrelogramsWidget): - possible_backends = {} + # possible_backends = {} def __init__(self, *args, **kargs): CrossCorrelogramsWidget.__init__(self, *args, **kargs) + def plot_matplotlib(self, data_plot, **backend_kwargs): + import 
matplotlib.pyplot as plt + from .matplotlib_utils import make_mpl_figure + + dp = to_attr(data_plot) + # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) + backend_kwargs["num_axes"] = len(dp.unit_ids) + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + # self.make_mpl_figure(**backend_kwargs) + + bins = dp.bins + unit_ids = dp.unit_ids + correlograms = dp.correlograms + bin_width = bins[1] - bins[0] + + for i, unit_id in enumerate(unit_ids): + ccg = correlograms[i, i] + ax = self.axes.flatten()[i] + if dp.unit_colors is None: + color = "g" + else: + color = dp.unit_colors[unit_id] + ax.bar(x=bins[:-1], height=ccg, width=bin_width, color=color, align="edge") + ax.set_title(str(unit_id)) + + def plot_sortingview(self, data_plot, **backend_kwargs): + import sortingview.views as vv + from .sortingview_utils import make_serializable, handle_display_and_url + + # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) + dp = to_attr(data_plot) + # unit_ids = self.make_serializable(dp.unit_ids) + unit_ids = make_serializable(dp.unit_ids) + + ac_items = [] + for i in range(len(unit_ids)): + for j in range(i, len(unit_ids)): + if i == j: + ac_items.append( + vv.AutocorrelogramItem( + unit_id=unit_ids[i], + bin_edges_sec=(dp.bins / 1000.0).astype("float32"), + bin_counts=dp.correlograms[i, j].astype("int32"), + ) + ) + + self.view = vv.Autocorrelograms(autocorrelograms=ac_items) + + # self.handle_display_and_url(v_autocorrelograms, **backend_kwargs) + # return v_autocorrelograms + self.url = handle_display_and_url(self, self.view, **self.backend_kwargs) + + + AutoCorrelogramsWidget.__doc__ = CrossCorrelogramsWidget.__doc__ diff --git a/src/spikeinterface/widgets/crosscorrelograms.py b/src/spikeinterface/widgets/crosscorrelograms.py index 8481c8ef0d..eed76c3e04 100644 --- a/src/spikeinterface/widgets/crosscorrelograms.py +++ b/src/spikeinterface/widgets/crosscorrelograms.py @@ -1,7 +1,7 @@ import numpy as np from typing import Union -from .base import BaseWidget +from .base import BaseWidget, to_attr from ..core.waveform_extractor import WaveformExtractor from ..core.basesorting import BaseSorting from ..postprocessing import compute_correlograms @@ -27,7 +27,7 @@ class CrossCorrelogramsWidget(BaseWidget): If given, a dictionary with unit ids as keys and colors as values, default None """ - possible_backends = {} + # possible_backends = {} def __init__( self, @@ -65,3 +65,69 @@ def __init__( ) BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .matplotlib_utils import make_mpl_figure + + dp = to_attr(data_plot) + # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) + backend_kwargs["ncols"] = len(dp.unit_ids) + backend_kwargs["num_axes"] = int(len(dp.unit_ids) ** 2) + + # self.make_mpl_figure(**backend_kwargs) + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + assert self.axes.ndim == 2 + + bins = dp.bins + unit_ids = dp.unit_ids + correlograms = dp.correlograms + bin_width = bins[1] - bins[0] + + for i, unit_id1 in enumerate(unit_ids): + for j, unit_id2 in enumerate(unit_ids): + ccg = correlograms[i, j] + ax = self.axes[i, j] + if i == j: + if dp.unit_colors is None: + color = "g" + else: + color = dp.unit_colors[unit_id1] + else: + color = "k" + ax.bar(x=bins[:-1], height=ccg, width=bin_width, color=color, align="edge") + + for i, unit_id in enumerate(unit_ids): + self.axes[0, i].set_title(str(unit_id)) 
+ self.axes[-1, i].set_xlabel("CCG (ms)") + + def plot_sortingview(self, data_plot, **backend_kwargs): + import sortingview.views as vv + from .sortingview_utils import generate_unit_table_view, make_serializable, handle_display_and_url + + # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) + dp = to_attr(data_plot) + + # unit_ids = self.make_serializable(dp.unit_ids) + unit_ids = make_serializable(dp.unit_ids) + + cc_items = [] + for i in range(len(unit_ids)): + for j in range(i, len(unit_ids)): + cc_items.append( + vv.CrossCorrelogramItem( + unit_id1=unit_ids[i], + unit_id2=unit_ids[j], + bin_edges_sec=(dp.bins / 1000.0).astype("float32"), + bin_counts=dp.correlograms[i, j].astype("int32"), + ) + ) + + self.view = vv.CrossCorrelograms( + cross_correlograms=cc_items, hide_unit_selector=dp.hide_unit_selector + ) + + # self.handle_display_and_url(v_cross_correlograms, **backend_kwargs) + # return v_cross_correlograms + self.url = handle_display_and_url(self, self.view, **self.backend_kwargs) diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index 52ee03ebca..fb3a611c60 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -10,8 +10,8 @@ # from .unit_waveforms_density_map import UnitWaveformDensityMapWidget # isi/ccg/acg -# from .autocorrelograms import AutoCorrelogramsWidget -# from .crosscorrelograms import CrossCorrelogramsWidget +from .autocorrelograms import AutoCorrelogramsWidget +from .crosscorrelograms import CrossCorrelogramsWidget # peak activity @@ -59,7 +59,7 @@ widget_list = [ AmplitudesWidget, AllAmplitudesDistributionsWidget, - # AutoCorrelogramsWidget, + AutoCorrelogramsWidget, # CrossCorrelogramsWidget, # QualityMetricsWidget, # SpikeLocationsWidget, @@ -132,4 +132,6 @@ plot_amplitudes = AmplitudesWidget plot_all_amplitudes_distributions = AllAmplitudesDistributionsWidget plot_unit_locations = UnitLocationsWidget +plot_autocorrelograms = AutoCorrelogramsWidget +plot_crosscorrelograms = CrossCorrelogramsWidget From 1087ee1441360c526225003c2d87522d97a60f40 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 19 Jul 2023 17:21:57 +0200 Subject: [PATCH 104/166] pyproject-toml changed --- .github/workflows/full-test.yml | 4 ++++ pyproject.toml | 1 + 2 files changed, 5 insertions(+) diff --git a/.github/workflows/full-test.yml b/.github/workflows/full-test.yml index 3e8b082c50..633b226e57 100644 --- a/.github/workflows/full-test.yml +++ b/.github/workflows/full-test.yml @@ -66,6 +66,10 @@ jobs: id: modules-changed run: | for file in ${{ steps.changed-files.outputs.all_changed_files }}; do + if [[ $file == *"pyproject.toml" ]]; then + echo "pyproject.toml changed" + echo "CORE_CHANGED=true" >> $GITHUB_OUTPUT + fi if [[ $file == *"/core/"* || $file == *"/extractors/neoextractors/neobaseextractor.py" ]]; then echo "Core changed" echo "CORE_CHANGED=true" >> $GITHUB_OUTPUT diff --git a/pyproject.toml b/pyproject.toml index bcdd76a3df..baeca9c959 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,7 @@ classifiers = [ "Operating System :: OS Independent" ] + dependencies = [ "numpy", "neo>=0.12.0", From 77e2c1fe5632f17df4504b94317248e6df284b80 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 19 Jul 2023 17:33:45 +0200 Subject: [PATCH 105/166] refactor widget : SpikeLocationsWidget --- src/spikeinterface/widgets/spike_locations.py | 231 +++++++++++++++++- src/spikeinterface/widgets/widget_list.py | 5 +- 2 files changed, 232 insertions(+), 4 deletions(-) 
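
As a side note on the correlogram widgets refactored in PATCH 103 above: the matplotlib path draws each correlogram as an edge-aligned bar plot. A rough stand-alone equivalent, with toy Poisson counts standing in for real correlograms (`bins` are bin edges in ms, as in the widget):

    import numpy as np
    import matplotlib.pyplot as plt

    bins = np.arange(-50.0, 51.0, 1.0)  # bin edges in ms
    ccg = np.random.poisson(5.0, size=bins.size - 1)  # toy counts per bin
    fig, ax = plt.subplots()
    ax.bar(x=bins[:-1], height=ccg, width=bins[1] - bins[0], color="g", align="edge")
    ax.set_title("unit 0")
    plt.show()
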
diff --git a/src/spikeinterface/widgets/spike_locations.py b/src/spikeinterface/widgets/spike_locations.py index da5ad5b08c..d32c3c2f4c 100644 --- a/src/spikeinterface/widgets/spike_locations.py +++ b/src/spikeinterface/widgets/spike_locations.py @@ -1,7 +1,7 @@ import numpy as np from typing import Union -from .base import BaseWidget +from .base import BaseWidget, to_attr from .utils import get_unit_colors from ..core.waveform_extractor import WaveformExtractor @@ -36,7 +36,7 @@ class SpikeLocationsWidget(BaseWidget): If True, the axis is set to off. Default False (matplotlib backend) """ - possible_backends = {} + # possible_backends = {} def __init__( self, @@ -105,6 +105,233 @@ def __init__( BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .matplotlib_utils import make_mpl_figure + from matplotlib.lines import Line2D + + from probeinterface import ProbeGroup + from probeinterface.plotting import plot_probe + + dp = to_attr(data_plot) + # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) + + # self.make_mpl_figure(**backend_kwargs) + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + spike_locations = dp.spike_locations + + probegroup = ProbeGroup.from_dict(dp.probegroup_dict) + probe_shape_kwargs = dict(facecolor="w", edgecolor="k", lw=0.5, alpha=1.0) + contacts_kargs = dict(alpha=1.0, edgecolor="k", lw=0.5) + + for probe in probegroup.probes: + text_on_contact = None + if dp.with_channel_ids: + text_on_contact = dp.channel_ids + + poly_contact, poly_contour = plot_probe( + probe, + ax=self.ax, + contacts_colors="w", + contacts_kargs=contacts_kargs, + probe_shape_kwargs=probe_shape_kwargs, + text_on_contact=text_on_contact, + ) + poly_contact.set_zorder(2) + if poly_contour is not None: + poly_contour.set_zorder(1) + + self.ax.set_title("") + + if dp.plot_all_units: + unit_colors = {} + unit_ids = dp.all_unit_ids + for unit in dp.all_unit_ids: + if unit not in dp.unit_ids: + unit_colors[unit] = "gray" + else: + unit_colors[unit] = dp.unit_colors[unit] + else: + unit_ids = dp.unit_ids + unit_colors = dp.unit_colors + labels = dp.unit_ids + + for i, unit in enumerate(unit_ids): + locs = spike_locations[unit] + + zorder = 5 if unit in dp.unit_ids else 3 + self.ax.scatter(locs["x"], locs["y"], s=2, alpha=0.3, color=unit_colors[unit], zorder=zorder) + + handles = [ + Line2D([0], [0], ls="", marker="o", markersize=5, markeredgewidth=2, color=unit_colors[unit]) + for unit in dp.unit_ids + ] + if dp.plot_legend: + # if self.legend is not None: + if hasattr(self, 'legend') and self.legend is not None: + self.legend.remove() + self.legend = self.figure.legend( + handles, labels, loc="upper center", bbox_to_anchor=(0.5, 1.0), ncol=5, fancybox=True, shadow=True + ) + + # set proper axis limits + xlims, ylims = estimate_axis_lims(spike_locations) + + ax_xlims = list(self.ax.get_xlim()) + ax_ylims = list(self.ax.get_ylim()) + + ax_xlims[0] = xlims[0] if xlims[0] < ax_xlims[0] else ax_xlims[0] + ax_xlims[1] = xlims[1] if xlims[1] > ax_xlims[1] else ax_xlims[1] + ax_ylims[0] = ylims[0] if ylims[0] < ax_ylims[0] else ax_ylims[0] + ax_ylims[1] = ylims[1] if ylims[1] > ax_ylims[1] else ax_ylims[1] + + self.ax.set_xlim(ax_xlims) + self.ax.set_ylim(ax_ylims) + if dp.hide_axis: + self.ax.axis("off") + + def plot_ipywidgets(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + import ipywidgets.widgets as widgets + from IPython.display 
import display + from .ipywidgets_utils import check_ipywidget_backend, make_unit_controller + + check_ipywidget_backend() + + self.next_data_plot = data_plot.copy() + + cm = 1 / 2.54 + + # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) + width_cm = backend_kwargs["width_cm"] + height_cm = backend_kwargs["height_cm"] + + ratios = [0.15, 0.85] + + with plt.ioff(): + output = widgets.Output() + with output: + fig, self.ax = plt.subplots(figsize=((ratios[1] * width_cm) * cm, height_cm * cm)) + plt.show() + + data_plot["unit_ids"] = data_plot["unit_ids"][:1] + + unit_widget, unit_controller = make_unit_controller( + data_plot["unit_ids"], + list(data_plot["unit_colors"].keys()), + ratios[0] * width_cm, + height_cm, + ) + + self.controller = unit_controller + + # mpl_plotter = MplSpikeLocationsPlotter() + + # self.updater = PlotUpdater(data_plot, mpl_plotter, ax, self.controller) + # for w in self.controller.values(): + # w.observe(self.updater) + + for w in self.controller.values(): + w.observe(self._update_ipywidget) + + self.widget = widgets.AppLayout( + center=fig.canvas, + left_sidebar=unit_widget, + pane_widths=ratios + [0], + ) + + # a first update + # self.updater(None) + self._update_ipywidget(None) + + + if backend_kwargs["display"]: + # self.check_backend() + display(self.widget) + + def _update_ipywidget(self, change): + + self.ax.clear() + + unit_ids = self.controller["unit_ids"].value + + # matplotlib next_data_plot dict update at each call + data_plot = self.next_data_plot + data_plot["unit_ids"] = unit_ids + data_plot["plot_all_units"] = True + data_plot["plot_legend"] = True + data_plot["hide_axis"] = True + + backend_kwargs = {} + backend_kwargs["ax"] = self.ax + + # self.mpl_plotter.do_plot(data_plot, **backend_kwargs) + self.plot_matplotlib(data_plot, **backend_kwargs) + fig = self.ax.get_figure() + fig.canvas.draw() + fig.canvas.flush_events() + + + def plot_sortingview(self, data_plot, **backend_kwargs): + import sortingview.views as vv + from .sortingview_utils import generate_unit_table_view, make_serializable, handle_display_and_url + + # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) + dp = to_attr(data_plot) + spike_locations = dp.spike_locations + + # ensure serializable for sortingview + # unit_ids, channel_ids = self.make_serializable(dp.unit_ids, dp.channel_ids) + unit_ids, channel_ids = make_serializable(dp.unit_ids, dp.channel_ids) + + locations = {str(ch): dp.channel_locations[i_ch].astype("float32") for i_ch, ch in enumerate(channel_ids)} + xlims, ylims = estimate_axis_lims(spike_locations) + + unit_items = [] + for unit in unit_ids: + spike_times_sec = dp.sorting.get_unit_spike_train( + unit_id=unit, segment_index=dp.segment_index, return_times=True + ) + unit_items.append( + vv.SpikeLocationsItem( + unit_id=unit, + spike_times_sec=spike_times_sec.astype("float32"), + x_locations=spike_locations[unit]["x"].astype("float32"), + y_locations=spike_locations[unit]["y"].astype("float32"), + ) + ) + + v_spike_locations = vv.SpikeLocations( + units=unit_items, + hide_unit_selector=dp.hide_unit_selector, + x_range=xlims.astype("float32"), + y_range=ylims.astype("float32"), + channel_locations=locations, + disable_auto_rotate=True, + ) + + if not dp.hide_unit_selector: + v_units_table = generate_unit_table_view(dp.sorting) + + self.view = vv.Box( + direction="horizontal", + items=[ + vv.LayoutItem(v_units_table, max_size=150), + vv.LayoutItem(v_spike_locations), + ], + ) + else: + self.view = v_spike_locations + + # self.set_view(view) + 
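
The `estimate_axis_lims(spike_locations)` call above derives robust plot limits from location quantiles (its default is `quantile=0.02`, per the helper's signature further down). A rough stand-alone sketch of the idea on toy data — the real helper works on structured `spike_locations` arrays with `x`/`y` fields:

    import numpy as np

    rng = np.random.default_rng(0)
    x = rng.normal(0.0, 20.0, size=1000)  # toy x locations (um)
    quantile = 0.02
    xlims = np.array([np.quantile(x, quantile), np.quantile(x, 1.0 - quantile)])
    print(xlims)  # roughly [-41, 41] for this toy distribution
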
+ # self.handle_display_and_url(view, **backend_kwargs) + # return view + self.url = handle_display_and_url(self, self.view, **self.backend_kwargs) + + + def estimate_axis_lims(spike_locations, quantile=0.02): # set proper axis limits diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index fb3a611c60..2a146b52b9 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -24,7 +24,7 @@ # units on probe from .unit_locations import UnitLocationsWidget -# from .spike_locations import SpikeLocationsWidget +from .spike_locations import SpikeLocationsWidget # unit presence @@ -62,7 +62,7 @@ AutoCorrelogramsWidget, # CrossCorrelogramsWidget, # QualityMetricsWidget, - # SpikeLocationsWidget, + SpikeLocationsWidget, # SpikesOnTracesWidget, # TemplateMetricsWidget, # MotionWidget, @@ -134,4 +134,5 @@ plot_unit_locations = UnitLocationsWidget plot_autocorrelograms = AutoCorrelogramsWidget plot_crosscorrelograms = CrossCorrelogramsWidget +plot_spike_locations = SpikeLocationsWidget From 0a3c1834b7dee41ee1a0ab409c7c8b812e943d60 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 19 Jul 2023 19:36:50 +0200 Subject: [PATCH 106/166] tridesclous 1.6.8 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index bcdd76a3df..7c767dd6e2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -129,7 +129,7 @@ test = [ "datalad==0.16.2", ## install tridesclous for testing ## - "tridesclous>=1.6.7", + "tridesclous>=1.6.8", ## sliding_nn "pymde", From 1bdb64f5e0d0a8dda32460efc92a6cd92b6c3e21 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 19 Jul 2023 20:52:03 +0200 Subject: [PATCH 107/166] widget refactor TemplateMetricsWidget QualityMetricsWidget --- src/spikeinterface/widgets/metrics.py | 211 +++++++++++++++++- src/spikeinterface/widgets/quality_metrics.py | 2 +- .../widgets/template_metrics.py | 2 +- src/spikeinterface/widgets/widget_list.py | 10 +- 4 files changed, 217 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/widgets/metrics.py b/src/spikeinterface/widgets/metrics.py index 8e77e4a0f0..207e3a8a6c 100644 --- a/src/spikeinterface/widgets/metrics.py +++ b/src/spikeinterface/widgets/metrics.py @@ -1,8 +1,9 @@ import warnings import numpy as np -from .base import BaseWidget +from .base import BaseWidget, to_attr from .utils import get_unit_colors +from ..core.core_tools import check_json class MetricsBaseWidget(BaseWidget): @@ -29,7 +30,7 @@ class MetricsBaseWidget(BaseWidget): If True, metrics data are included in unit table, by default True """ - possible_backends = {} + # possible_backends = {} def __init__( self, @@ -77,3 +78,209 @@ def __init__( ) BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .matplotlib_utils import make_mpl_figure + + dp = to_attr(data_plot) + metrics = dp.metrics + num_metrics = len(metrics.columns) + + if "figsize" not in backend_kwargs: + backend_kwargs["figsize"] = (2 * num_metrics, 2 * num_metrics) + + # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) + backend_kwargs["num_axes"] = num_metrics ** 2 + backend_kwargs["ncols"] = num_metrics + + all_unit_ids = metrics.index.values + + # self.make_mpl_figure(**backend_kwargs) + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + assert self.axes.ndim == 2 + + if dp.unit_ids is None: + colors = 
["gray"] * len(all_unit_ids) + else: + colors = [] + for unit in all_unit_ids: + color = "gray" if unit not in dp.unit_ids else dp.unit_colors[unit] + colors.append(color) + + self.patches = [] + for i, m1 in enumerate(metrics.columns): + for j, m2 in enumerate(metrics.columns): + if i == j: + self.axes[i, j].hist(metrics[m1], color="gray") + else: + p = self.axes[i, j].scatter(metrics[m1], metrics[m2], c=colors, s=3, marker="o") + self.patches.append(p) + if i == num_metrics - 1: + self.axes[i, j].set_xlabel(m2, fontsize=10) + if j == 0: + self.axes[i, j].set_ylabel(m1, fontsize=10) + self.axes[i, j].set_xticklabels([]) + self.axes[i, j].set_yticklabels([]) + self.axes[i, j].spines["top"].set_visible(False) + self.axes[i, j].spines["right"].set_visible(False) + + self.figure.subplots_adjust(top=0.8, wspace=0.2, hspace=0.2) + + + def plot_ipywidgets(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + import ipywidgets.widgets as widgets + from IPython.display import display + from .ipywidgets_utils import check_ipywidget_backend, make_unit_controller + + check_ipywidget_backend() + + self.next_data_plot = data_plot.copy() + + cm = 1 / 2.54 + + # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) + width_cm = backend_kwargs["width_cm"] + height_cm = backend_kwargs["height_cm"] + + ratios = [0.15, 0.85] + + with plt.ioff(): + output = widgets.Output() + with output: + self.figure = plt.figure(figsize=((ratios[1] * width_cm) * cm, height_cm * cm)) + plt.show() + if data_plot["unit_ids"] is None: + data_plot["unit_ids"] = [] + + unit_widget, unit_controller = make_unit_controller( + data_plot["unit_ids"], list(data_plot["unit_colors"].keys()), ratios[0] * width_cm, height_cm + ) + + self.controller = unit_controller + + # mpl_plotter = MplMetricsPlotter() + + # self.updater = PlotUpdater(data_plot, mpl_plotter, fig, self.controller) + # for w in self.controller.values(): + # w.observe(self.updater) + for w in self.controller.values(): + w.observe(self._update_ipywidget) + + + self.widget = widgets.AppLayout( + center=self.figure.canvas, + left_sidebar=unit_widget, + pane_widths=ratios + [0], + ) + + # a first update + # self.updater(None) + self._update_ipywidget(None) + + if backend_kwargs["display"]: + # self.check_backend() + display(self.widget) + + def _update_ipywidget(self, change): + from matplotlib.lines import Line2D + + unit_ids = self.controller["unit_ids"].value + + unit_colors = self.data_plot["unit_colors"] + # matplotlib next_data_plot dict update at each call + all_units = list(unit_colors.keys()) + colors = [] + sizes = [] + for unit in all_units: + color = "gray" if unit not in unit_ids else unit_colors[unit] + size = 1 if unit not in unit_ids else 5 + colors.append(color) + sizes.append(size) + + # here we do a trick: we just update colors + # if hasattr(self.mpl_plotter, "patches"): + if hasattr(self, "patches"): + # for p in self.mpl_plotter.patches: + for p in self.patches: + p.set_color(colors) + p.set_sizes(sizes) + else: + backend_kwargs = {} + backend_kwargs["figure"] = self.figure + # self.mpl_plotter.do_plot(self.data_plot, **backend_kwargs) + self.plot_matplotlib(self.data_plot, **backend_kwargs) + + if len(unit_ids) > 0: + for l in self.figure.legends: + l.remove() + handles = [ + Line2D([0], [0], ls="", marker="o", markersize=5, markeredgewidth=2, color=unit_colors[unit]) + for unit in unit_ids + ] + labels = unit_ids + self.figure.legend( + handles, labels, loc="upper center", bbox_to_anchor=(0.5, 1.0), ncol=5, fancybox=True, 
shadow=True + ) + + self.figure.canvas.draw() + self.figure.canvas.flush_events() + + def plot_sortingview(self, data_plot, **backend_kwargs): + import sortingview.views as vv + from .sortingview_utils import generate_unit_table_view, make_serializable, handle_display_and_url + + dp = to_attr(data_plot) + # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) + + metrics = dp.metrics + metric_names = list(metrics.columns) + + if dp.unit_ids is None: + unit_ids = metrics.index.values + else: + unit_ids = dp.unit_ids + # unit_ids = self.make_serializable(unit_ids) + unit_ids = make_serializable(unit_ids) + + metrics_sv = [] + for col in metric_names: + dtype = metrics.iloc[0][col].dtype + metric = vv.UnitMetricsGraphMetric(key=col, label=col, dtype=dtype.str) + metrics_sv.append(metric) + + units_m = [] + for unit_id in unit_ids: + values = check_json(metrics.loc[unit_id].to_dict()) + values_skip_nans = {} + for k, v in values.items(): + if np.isnan(v): + continue + values_skip_nans[k] = v + + units_m.append(vv.UnitMetricsGraphUnit(unit_id=unit_id, values=values_skip_nans)) + v_metrics = vv.UnitMetricsGraph(units=units_m, metrics=metrics_sv) + + if not dp.hide_unit_selector: + if dp.include_metrics_data: + # make a view of the sorting to add tmp properties + sorting_copy = dp.sorting.select_units(unit_ids=dp.sorting.unit_ids) + for col in metric_names: + if col not in sorting_copy.get_property_keys(): + sorting_copy.set_property(col, metrics[col].values) + # generate table with properties + v_units_table = generate_unit_table_view(sorting_copy, unit_properties=metric_names) + else: + v_units_table = generate_unit_table_view(dp.sorting) + + self.view = vv.Splitter( + direction="horizontal", item1=vv.LayoutItem(v_units_table), item2=vv.LayoutItem(v_metrics) + ) + else: + self.view = v_metrics + + # self.handle_display_and_url(view, **backend_kwargs) + # return view + self.url = handle_display_and_url(self, self.view, **self.backend_kwargs) \ No newline at end of file diff --git a/src/spikeinterface/widgets/quality_metrics.py b/src/spikeinterface/widgets/quality_metrics.py index f1c2ad6e23..46bcd6c07b 100644 --- a/src/spikeinterface/widgets/quality_metrics.py +++ b/src/spikeinterface/widgets/quality_metrics.py @@ -23,7 +23,7 @@ class QualityMetricsWidget(MetricsBaseWidget): For sortingview backend, if True the unit selector is not displayed, default False """ - possible_backends = {} + # possible_backends = {} def __init__( self, diff --git a/src/spikeinterface/widgets/template_metrics.py b/src/spikeinterface/widgets/template_metrics.py index b441882730..7361757666 100644 --- a/src/spikeinterface/widgets/template_metrics.py +++ b/src/spikeinterface/widgets/template_metrics.py @@ -22,7 +22,7 @@ class TemplateMetricsWidget(MetricsBaseWidget): For sortingview backend, if True the unit selector is not displayed, default False """ - possible_backends = {} + # possible_backends = {} def __init__( self, diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index 2a146b52b9..e9e2b179b0 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -38,8 +38,8 @@ from .all_amplitudes_distributions import AllAmplitudesDistributionsWidget # metrics -# from .quality_metrics import QualityMetricsWidget -# from .template_metrics import TemplateMetricsWidget +from .quality_metrics import QualityMetricsWidget +from .template_metrics import TemplateMetricsWidget # motion/drift @@ -61,10 +61,10 @@ 
AllAmplitudesDistributionsWidget, AutoCorrelogramsWidget, # CrossCorrelogramsWidget, - # QualityMetricsWidget, + QualityMetricsWidget, SpikeLocationsWidget, # SpikesOnTracesWidget, - # TemplateMetricsWidget, + TemplateMetricsWidget, # MotionWidget, # TemplateSimilarityWidget, # TimeseriesWidget, @@ -135,4 +135,6 @@ plot_autocorrelograms = AutoCorrelogramsWidget plot_crosscorrelograms = CrossCorrelogramsWidget plot_spike_locations = SpikeLocationsWidget +plot_template_metrics = TemplateMetricsWidget +plot_quality_metrics = QualityMetricsWidget From 5394263962e8f2e6370881af62e22817677be0ce Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 19 Jul 2023 21:01:42 +0200 Subject: [PATCH 108/166] widget refactor MotionWidget --- src/spikeinterface/widgets/motion.py | 128 +++++++++++++++++++++- src/spikeinterface/widgets/widget_list.py | 7 +- 2 files changed, 130 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py index 82e9be2407..48aba8de47 100644 --- a/src/spikeinterface/widgets/motion.py +++ b/src/spikeinterface/widgets/motion.py @@ -1,7 +1,7 @@ import numpy as np from warnings import warn -from .base import BaseWidget +from .base import BaseWidget, to_attr from .utils import get_unit_colors @@ -36,7 +36,7 @@ class MotionWidget(BaseWidget): The alpha of the scatter points, default 0.5 """ - possible_backends = {} + # possible_backends = {} def __init__( self, @@ -68,3 +68,127 @@ def __init__( ) BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .matplotlib_utils import make_mpl_figure + from matplotlib.colors import Normalize + + from spikeinterface.sortingcomponents.motion_interpolation import correct_motion_on_peaks + + + dp = to_attr(data_plot) + # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) + + + assert backend_kwargs["axes"] is None + assert backend_kwargs["ax"] is None + + # self.make_mpl_figure(**backend_kwargs) + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + fig = self.figure + fig.clear() + + is_rigid = dp.motion.shape[1] == 1 + + gs = fig.add_gridspec(2, 2, wspace=0.3, hspace=0.3) + ax0 = fig.add_subplot(gs[0, 0]) + ax1 = fig.add_subplot(gs[0, 1]) + ax2 = fig.add_subplot(gs[1, 0]) + if not is_rigid: + ax3 = fig.add_subplot(gs[1, 1]) + ax1.sharex(ax0) + ax1.sharey(ax0) + + if dp.motion_lim is None: + motion_lim = np.max(np.abs(dp.motion)) * 1.05 + else: + motion_lim = dp.motion_lim + + if dp.times is None: + temporal_bins_plot = dp.temporal_bins + x = dp.peaks["sample_index"] / dp.sampling_frequency + else: + # use real times and adjust temporal bins with t_start + temporal_bins_plot = dp.temporal_bins + dp.times[0] + x = dp.times[dp.peaks["sample_index"]] + + corrected_location = correct_motion_on_peaks( + dp.peaks, + dp.peak_locations, + dp.sampling_frequency, + dp.motion, + dp.temporal_bins, + dp.spatial_bins, + direction="y", + ) + + y = dp.peak_locations["y"] + y2 = corrected_location["y"] + if dp.scatter_decimate is not None: + x = x[:: dp.scatter_decimate] + y = y[:: dp.scatter_decimate] + y2 = y2[:: dp.scatter_decimate] + + if dp.color_amplitude: + amps = dp.peaks["amplitude"] + amps_abs = np.abs(amps) + q_95 = np.quantile(amps_abs, 0.95) + if dp.scatter_decimate is not None: + amps = amps[:: dp.scatter_decimate] + amps_abs = amps_abs[:: dp.scatter_decimate] + cmap = plt.get_cmap(dp.amplitude_cmap) + if dp.amplitude_clim is None: + amps = 
amps_abs + amps /= q_95 + c = cmap(amps) + else: + norm_function = Normalize(vmin=dp.amplitude_clim[0], vmax=dp.amplitude_clim[1], clip=True) + c = cmap(norm_function(amps)) + color_kwargs = dict( + color=None, + c=c, + alpha=dp.amplitude_alpha, + ) + else: + color_kwargs = dict(color="k", c=None, alpha=dp.amplitude_alpha) + + ax0.scatter(x, y, s=1, **color_kwargs) + if dp.depth_lim is not None: + ax0.set_ylim(*dp.depth_lim) + ax0.set_title("Peak depth") + ax0.set_xlabel("Times [s]") + ax0.set_ylabel("Depth [um]") + + ax1.scatter(x, y2, s=1, **color_kwargs) + ax1.set_xlabel("Times [s]") + ax1.set_ylabel("Depth [um]") + ax1.set_title("Corrected peak depth") + + ax2.plot(temporal_bins_plot, dp.motion, alpha=0.2, color="black") + ax2.plot(temporal_bins_plot, np.mean(dp.motion, axis=1), color="C0") + ax2.set_ylim(-motion_lim, motion_lim) + ax2.set_ylabel("Motion [um]") + ax2.set_title("Motion vectors") + axes = [ax0, ax1, ax2] + + if not is_rigid: + im = ax3.imshow( + dp.motion.T, + aspect="auto", + origin="lower", + extent=( + temporal_bins_plot[0], + temporal_bins_plot[-1], + dp.spatial_bins[0], + dp.spatial_bins[-1], + ), + ) + im.set_clim(-motion_lim, motion_lim) + cbar = fig.colorbar(im) + cbar.ax.set_xlabel("motion [um]") + ax3.set_xlabel("Times [s]") + ax3.set_ylabel("Depth [um]") + ax3.set_title("Motion vectors") + axes.append(ax3) + self.axes = np.array(axes) \ No newline at end of file diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index e9e2b179b0..897965b4eb 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -43,7 +43,7 @@ # motion/drift -# from .motion import MotionWidget +from .motion import MotionWidget # similarity # from .template_similarity import TemplateSimilarityWidget @@ -60,12 +60,12 @@ AmplitudesWidget, AllAmplitudesDistributionsWidget, AutoCorrelogramsWidget, - # CrossCorrelogramsWidget, + CrossCorrelogramsWidget, QualityMetricsWidget, SpikeLocationsWidget, # SpikesOnTracesWidget, TemplateMetricsWidget, - # MotionWidget, + MotionWidget, # TemplateSimilarityWidget, # TimeseriesWidget, UnitLocationsWidget, @@ -137,4 +137,5 @@ plot_spike_locations = SpikeLocationsWidget plot_template_metrics = TemplateMetricsWidget plot_quality_metrics = QualityMetricsWidget +plot_motion = MotionWidget From ddf0d8d3e417acd652c1784b9cc0092f49fb4670 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 19 Jul 2023 21:06:36 +0200 Subject: [PATCH 109/166] refactor widget : TemplateSimilarityWidget --- .../widgets/template_similarity.py | 56 ++++++++++++++++++- src/spikeinterface/widgets/widget_list.py | 5 +- 2 files changed, 56 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/widgets/template_similarity.py b/src/spikeinterface/widgets/template_similarity.py index 475c873c29..93b9a49f49 100644 --- a/src/spikeinterface/widgets/template_similarity.py +++ b/src/spikeinterface/widgets/template_similarity.py @@ -1,9 +1,8 @@ import numpy as np from typing import Union -from .base import BaseWidget +from .base import BaseWidget, to_attr from ..core.waveform_extractor import WaveformExtractor -from ..core.basesorting import BaseSorting class TemplateSimilarityWidget(BaseWidget): @@ -27,7 +26,7 @@ class TemplateSimilarityWidget(BaseWidget): If True, color bar is displayed, default True. 
""" - possible_backends = {} + # possible_backends = {} def __init__( self, @@ -63,3 +62,54 @@ def __init__( ) BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .matplotlib_utils import make_mpl_figure + + dp = to_attr(data_plot) + # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) + + # self.make_mpl_figure(**backend_kwargs) + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + im = self.ax.matshow(dp.similarity, cmap=dp.cmap) + + if dp.show_unit_ticks: + # Major ticks + self.ax.set_xticks(np.arange(0, len(dp.unit_ids))) + self.ax.set_yticks(np.arange(0, len(dp.unit_ids))) + self.ax.xaxis.tick_bottom() + + # Labels for major ticks + self.ax.set_yticklabels(dp.unit_ids, fontsize=12) + self.ax.set_xticklabels(dp.unit_ids, fontsize=12) + if dp.show_colorbar: + self.figure.colorbar(im) + + def plot_sortingview(self, data_plot, **backend_kwargs): + import sortingview.views as vv + from .sortingview_utils import generate_unit_table_view, make_serializable, handle_display_and_url + + + # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) + dp = to_attr(data_plot) + + # ensure serializable for sortingview + # unit_ids = self.make_serializable(dp.unit_ids) + unit_ids = make_serializable(dp.unit_ids) + + # similarity + ss_items = [] + for i1, u1 in enumerate(unit_ids): + for i2, u2 in enumerate(unit_ids): + ss_items.append( + vv.UnitSimilarityScore(unit_id1=u1, unit_id2=u2, similarity=dp.similarity[i1, i2].astype("float32")) + ) + + self.view = vv.UnitSimilarityMatrix(unit_ids=list(unit_ids), similarity_scores=ss_items) + + # self.handle_display_and_url(view, **backend_kwargs) + # return view + self.url = handle_display_and_url(self, self.view, **self.backend_kwargs) + diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index 897965b4eb..e2366920e5 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -46,7 +46,7 @@ from .motion import MotionWidget # similarity -# from .template_similarity import TemplateSimilarityWidget +from .template_similarity import TemplateSimilarityWidget # from .unit_depths import UnitDepthsWidget @@ -66,7 +66,7 @@ # SpikesOnTracesWidget, TemplateMetricsWidget, MotionWidget, - # TemplateSimilarityWidget, + TemplateSimilarityWidget, # TimeseriesWidget, UnitLocationsWidget, # UnitTemplatesWidget, @@ -138,4 +138,5 @@ plot_template_metrics = TemplateMetricsWidget plot_quality_metrics = QualityMetricsWidget plot_motion = MotionWidget +plot_template_similarity = TemplateSimilarityWidget From 9890419db8fb8487d04fae471438002f67070a40 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 19 Jul 2023 21:30:21 +0200 Subject: [PATCH 110/166] refactor widgets UnitTemplatesWidget UnitWaveformsWidget --- src/spikeinterface/widgets/unit_templates.py | 53 +++- src/spikeinterface/widgets/unit_waveforms.py | 250 ++++++++++++++++++- src/spikeinterface/widgets/widget_list.py | 10 +- 3 files changed, 305 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/widgets/unit_templates.py b/src/spikeinterface/widgets/unit_templates.py index 41c4ece09c..84856d2df4 100644 --- a/src/spikeinterface/widgets/unit_templates.py +++ b/src/spikeinterface/widgets/unit_templates.py @@ -1,12 +1,61 @@ from .unit_waveforms import UnitWaveformsWidget - +from .base import to_attr class UnitTemplatesWidget(UnitWaveformsWidget): - possible_backends 
= {} + # possible_backends = {} def __init__(self, *args, **kargs): kargs["plot_waveforms"] = False UnitWaveformsWidget.__init__(self, *args, **kargs) + def plot_sortingview(self, data_plot, **backend_kwargs): + import sortingview.views as vv + from .sortingview_utils import generate_unit_table_view, make_serializable, handle_display_and_url + + dp = to_attr(data_plot) + # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) + + # ensure serializable for sortingview + unit_id_to_channel_ids = dp.sparsity.unit_id_to_channel_ids + unit_id_to_channel_indices = dp.sparsity.unit_id_to_channel_indices + + # unit_ids, channel_ids = self.make_serializable(dp.unit_ids, dp.channel_ids) + unit_ids, channel_ids = make_serializable(dp.unit_ids, dp.channel_ids) + + templates_dict = {} + for u_i, unit in enumerate(unit_ids): + templates_dict[unit] = {} + templates_dict[unit]["mean"] = dp.templates[u_i].T.astype("float32")[unit_id_to_channel_indices[unit]] + templates_dict[unit]["std"] = dp.template_stds[u_i].T.astype("float32")[unit_id_to_channel_indices[unit]] + + aw_items = [ + vv.AverageWaveformItem( + unit_id=u, + channel_ids=list(unit_id_to_channel_ids[u]), + waveform=t["mean"].astype("float32"), + waveform_std_dev=t["std"].astype("float32"), + ) + for u, t in templates_dict.items() + ] + + locations = {str(ch): dp.channel_locations[i_ch].astype("float32") for i_ch, ch in enumerate(channel_ids)} + v_average_waveforms = vv.AverageWaveforms(average_waveforms=aw_items, channel_locations=locations) + + if not dp.hide_unit_selector: + v_units_table = generate_unit_table_view(dp.waveform_extractor.sorting) + + self.view = vv.Box( + direction="horizontal", + items=[vv.LayoutItem(v_units_table, max_size=150), vv.LayoutItem(v_average_waveforms)], + ) + else: + self.view = v_average_waveforms + + # self.handle_display_and_url(view, **backend_kwargs) + # return view + self.url = handle_display_and_url(self, self.view, **self.backend_kwargs) + + + UnitTemplatesWidget.__doc__ = UnitWaveformsWidget.__doc__ diff --git a/src/spikeinterface/widgets/unit_waveforms.py b/src/spikeinterface/widgets/unit_waveforms.py index ba707a8221..49c75bf046 100644 --- a/src/spikeinterface/widgets/unit_waveforms.py +++ b/src/spikeinterface/widgets/unit_waveforms.py @@ -1,6 +1,6 @@ import numpy as np -from .base import BaseWidget +from .base import BaseWidget, to_attr from .utils import get_unit_colors from ..core import ChannelSparsity @@ -59,7 +59,7 @@ class UnitWaveformsWidget(BaseWidget): Display legend, default True """ - possible_backends = {} + # possible_backends = {} def __init__( self, @@ -165,6 +165,252 @@ def __init__( BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .matplotlib_utils import make_mpl_figure + from probeinterface.plotting import plot_probe + + from matplotlib.patches import Ellipse + from matplotlib.lines import Line2D + + dp = to_attr(data_plot) + + # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) + + if backend_kwargs.get("axes", None) is not None: + assert len(backend_kwargs["axes"]) >= len(dp.unit_ids), "Provide as many 'axes' as neurons" + elif backend_kwargs.get("ax", None) is not None: + assert dp.same_axis, "If 'same_axis' is not used, provide as many 'axes' as neurons" + else: + if dp.same_axis: + backend_kwargs["num_axes"] = 1 + backend_kwargs["ncols"] = None + else: + backend_kwargs["num_axes"] = len(dp.unit_ids) + backend_kwargs["ncols"] = 
min(dp.ncols, len(dp.unit_ids)) + + # self.make_mpl_figure(**backend_kwargs) + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + for i, unit_id in enumerate(dp.unit_ids): + if dp.same_axis: + ax = self.ax + else: + ax = self.axes.flatten()[i] + color = dp.unit_colors[unit_id] + + chan_inds = dp.sparsity.unit_id_to_channel_indices[unit_id] + xvectors_flat = dp.xvectors[:, chan_inds].T.flatten() + + # plot waveforms + if dp.plot_waveforms: + wfs = dp.wfs_by_ids[unit_id] + if dp.unit_selected_waveforms is not None: + wfs = wfs[dp.unit_selected_waveforms[unit_id]] + elif dp.max_spikes_per_unit is not None: + if len(wfs) > dp.max_spikes_per_unit: + random_idxs = np.random.permutation(len(wfs))[: dp.max_spikes_per_unit] + wfs = wfs[random_idxs] + wfs = wfs * dp.y_scale + dp.y_offset[None, :, chan_inds] + wfs_flat = wfs.swapaxes(1, 2).reshape(wfs.shape[0], -1).T + + if dp.x_offset_units: + # 0.7 is to match spacing in xvect + xvec = xvectors_flat + i * 0.7 * dp.delta_x + else: + xvec = xvectors_flat + + ax.plot(xvec, wfs_flat, lw=dp.lw_waveforms, alpha=dp.alpha_waveforms, color=color) + + if not dp.plot_templates: + ax.get_lines()[-1].set_label(f"{unit_id}") + + # plot template + if dp.plot_templates: + template = dp.templates[i, :, :][:, chan_inds] * dp.y_scale + dp.y_offset[:, chan_inds] + + if dp.x_offset_units: + # 0.7 is to match spacing in xvect + xvec = xvectors_flat + i * 0.7 * dp.delta_x + else: + xvec = xvectors_flat + + ax.plot( + xvec, template.T.flatten(), lw=dp.lw_templates, alpha=dp.alpha_templates, color=color, label=unit_id + ) + + template_label = dp.unit_ids[i] + if dp.set_title: + ax.set_title(f"template {template_label}") + + # plot channels + if dp.plot_channels: + # TODO enhance this + ax.scatter(dp.channel_locations[:, 0], dp.channel_locations[:, 1], color="k") + + if dp.same_axis and dp.plot_legend: + # if self.legend is not None: + if hasattr(self, 'legend') and self.legend is not None: + self.legend.remove() + self.legend = self.figure.legend( + loc="upper center", bbox_to_anchor=(0.5, 1.0), ncol=5, fancybox=True, shadow=True + ) + + def plot_ipywidgets(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + import ipywidgets.widgets as widgets + from IPython.display import display + from .ipywidgets_utils import check_ipywidget_backend, make_unit_controller + + check_ipywidget_backend() + + self.next_data_plot = data_plot.copy() + + cm = 1 / 2.54 + self.we = we = data_plot["waveform_extractor"] + + # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) + width_cm = backend_kwargs["width_cm"] + height_cm = backend_kwargs["height_cm"] + + ratios = [0.1, 0.7, 0.2] + + with plt.ioff(): + output1 = widgets.Output() + with output1: + self.fig_wf = plt.figure(figsize=((ratios[1] * width_cm) * cm, height_cm * cm)) + plt.show() + output2 = widgets.Output() + with output2: + self.fig_probe, self.ax_probe = plt.subplots(figsize=((ratios[2] * width_cm) * cm, height_cm * cm)) + plt.show() + + data_plot["unit_ids"] = data_plot["unit_ids"][:1] + unit_widget, unit_controller = make_unit_controller( + data_plot["unit_ids"], we.unit_ids, ratios[0] * width_cm, height_cm + ) + + same_axis_button = widgets.Checkbox( + value=False, + description="same axis", + disabled=False, + ) + + plot_templates_button = widgets.Checkbox( + value=True, + description="plot templates", + disabled=False, + ) + + hide_axis_button = widgets.Checkbox( + value=True, + description="hide axis", + disabled=False, + ) + + footer = widgets.HBox([same_axis_button, 
plot_templates_button, hide_axis_button]) + + self.controller = { + "same_axis": same_axis_button, + "plot_templates": plot_templates_button, + "hide_axis": hide_axis_button, + } + self.controller.update(unit_controller) + + # mpl_plotter = MplUnitWaveformPlotter() + + # self.updater = PlotUpdater(data_plot, mpl_plotter, fig_wf, ax_probe, self.controller) + # for w in self.controller.values(): + # w.observe(self.updater) + + for w in self.controller.values(): + w.observe(self._update_ipywidget) + + + self.widget = widgets.AppLayout( + center=self.fig_wf.canvas, + left_sidebar=unit_widget, + right_sidebar=self.fig_probe.canvas, + pane_widths=ratios, + footer=footer, + ) + + # a first update + # self.updater(None) + self._update_ipywidget(None) + + if backend_kwargs["display"]: + # self.check_backend() + display(self.widget) + + def _update_ipywidget(self, change): + self.fig_wf.clear() + self.ax_probe.clear() + + unit_ids = self.controller["unit_ids"].value + same_axis = self.controller["same_axis"].value + plot_templates = self.controller["plot_templates"].value + hide_axis = self.controller["hide_axis"].value + + # matplotlib next_data_plot dict update at each call + data_plot = self.next_data_plot + data_plot["unit_ids"] = unit_ids + data_plot["templates"] = self.we.get_all_templates(unit_ids=unit_ids) + data_plot["template_stds"] = self.we.get_all_templates(unit_ids=unit_ids, mode="std") + data_plot["same_axis"] = same_axis + data_plot["plot_templates"] = plot_templates + if data_plot["plot_waveforms"]: + data_plot["wfs_by_ids"] = {unit_id: self.we.get_waveforms(unit_id) for unit_id in unit_ids} + + backend_kwargs = {} + + if same_axis: + backend_kwargs["ax"] = self.fig_wf.add_subplot() + data_plot["set_title"] = False + else: + backend_kwargs["figure"] = self.fig_wf + + # self.mpl_plotter.do_plot(data_plot, **backend_kwargs) + self.plot_matplotlib(data_plot, **backend_kwargs) + if same_axis: + # self.mpl_plotter.ax.axis("equal") + self.ax.axis("equal") + if hide_axis: + # self.mpl_plotter.ax.axis("off") + self.ax.axis("off") + else: + if hide_axis: + for i in range(len(unit_ids)): + # ax = self.mpl_plotter.axes.flatten()[i] + ax = self.axes.flatten()[i] + ax.axis("off") + + # update probe plot + channel_locations = self.we.get_channel_locations() + self.ax_probe.plot( + channel_locations[:, 0], channel_locations[:, 1], ls="", marker="o", color="gray", markersize=2, alpha=0.5 + ) + self.ax_probe.axis("off") + self.ax_probe.axis("equal") + + for unit in unit_ids: + channel_inds = data_plot["sparsity"].unit_id_to_channel_indices[unit] + self.ax_probe.plot( + channel_locations[channel_inds, 0], + channel_locations[channel_inds, 1], + ls="", + marker="o", + markersize=3, + color=self.next_data_plot["unit_colors"][unit], + ) + self.ax_probe.set_xlim(np.min(channel_locations[:, 0]) - 10, np.max(channel_locations[:, 0]) + 10) + fig_probe = self.ax_probe.get_figure() + + self.fig_wf.canvas.draw() + self.fig_wf.canvas.flush_events() + fig_probe.canvas.draw() + fig_probe.canvas.flush_events() + def get_waveforms_scales(we, templates, channel_locations, x_offset_units=False): """ diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index e2366920e5..cb19eda93c 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -5,8 +5,8 @@ # from .timeseries import TimeseriesWidget # waveform -# from .unit_waveforms import UnitWaveformsWidget -# from .unit_templates import UnitTemplatesWidget +from .unit_waveforms 
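With this patch the per-backend plotters live on the widget class itself (``plot_matplotlib``, ``plot_ipywidgets``, ``plot_sortingview``) and ``BaseWidget`` dispatches on the ``backend`` argument. A minimal usage sketch, assuming an existing ``WaveformExtractor`` named ``we`` (the ``plot_unit_templates`` alias is the one registered in ``widget_list.py`` above):

.. code-block:: python

    from spikeinterface.widgets import plot_unit_templates

    # the backend argument selects which plot_* method BaseWidget dispatches to
    w = plot_unit_templates(we, backend="matplotlib")
    w.figure.savefig("unit_templates.png")  # the matplotlib backend exposes .figure/.axes/.ax
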
From f064513b1631697b7db83197d8113852edd592e8 Mon Sep 17 00:00:00 2001
From: Samuel Garcia
Date: Wed, 19 Jul 2023 21:36:39 +0200
Subject: [PATCH 111/166] widget refactor : UnitWaveformDensityMapWidget

---
 .../widgets/unit_waveforms_density_map.py | 76 ++++++++++++++++++-
 src/spikeinterface/widgets/widget_list.py |  5 +-
 2 files changed, 77 insertions(+), 4 deletions(-)

diff --git a/src/spikeinterface/widgets/unit_waveforms_density_map.py b/src/spikeinterface/widgets/unit_waveforms_density_map.py
index 9f3e5e86b5..9216373d87 100644
--- a/src/spikeinterface/widgets/unit_waveforms_density_map.py
+++ b/src/spikeinterface/widgets/unit_waveforms_density_map.py
@@ -1,6 +1,6 @@
 import numpy as np

-from .base import BaseWidget
+from .base import BaseWidget, to_attr
 from .utils import get_unit_colors
 from ..core import ChannelSparsity, get_template_extremum_channel

@@ -33,7 +33,7 @@ class UnitWaveformDensityMapWidget(BaseWidget):
         all channel per units, default False
     """

-    possible_backends = {}
+    # possible_backends = {}

     def __init__(
         self,
@@ -156,3 +156,75 @@ def __init__(
         )

         BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs)
+
+    def plot_matplotlib(self, data_plot, **backend_kwargs):
+        import matplotlib.pyplot as plt
+        from .matplotlib_utils import make_mpl_figure
+
+        dp = to_attr(data_plot)
+        # backend_kwargs = self.update_backend_kwargs(**backend_kwargs)
+
+        if backend_kwargs["axes"] is not None or backend_kwargs["ax"] is not None:
+            # self.make_mpl_figure(**backend_kwargs)
+            self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs)
+        else:
+            if dp.same_axis:
+                num_axes = 1
+            else:
+                num_axes = len(dp.unit_ids)
+            backend_kwargs["ncols"] = 1
+            backend_kwargs["num_axes"] = num_axes
+            # self.make_mpl_figure(**backend_kwargs)
+            self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs)
+
+        if dp.same_axis:
+            ax = self.ax
+            hist2d = dp.all_hist2d
+            im = ax.imshow(
+                hist2d.T,
+                interpolation="nearest",
+                origin="lower",
+                aspect="auto",
+                extent=(0, hist2d.shape[0], dp.bin_min, dp.bin_max),
+                cmap="hot",
+            )
+        else:
+            for unit_index, unit_id in enumerate(dp.unit_ids):
+                hist2d = dp.all_hist2d[unit_id]
+                ax = self.axes.flatten()[unit_index]
+                im = ax.imshow(
+                    hist2d.T,
+                    interpolation="nearest",
+                    origin="lower",
+                    aspect="auto",
+                    extent=(0, hist2d.shape[0], dp.bin_min, dp.bin_max),
+                    cmap="hot",
+                )
+
+        for unit_index, unit_id in enumerate(dp.unit_ids):
+            if dp.same_axis:
+                ax = self.ax
+            else:
+                ax = self.axes.flatten()[unit_index]
+            color = dp.unit_colors[unit_id]
+            ax.plot(dp.templates_flat[unit_id], color=color, lw=1)
+
+        # final cosmetics
+        for unit_index, unit_id in enumerate(dp.unit_ids):
+            if dp.same_axis:
+                ax = self.ax
+                if unit_index != 0:
+                    continue
+            else:
+                ax = self.axes.flatten()[unit_index]
+            chan_inds = dp.channel_inds[unit_id]
+            for i, chan_ind in enumerate(chan_inds):
+                if i != 0:
+                    ax.axvline(i * dp.template_width, color="w", lw=3)
+                channel_id = dp.channel_ids[chan_ind]
+                x = i * dp.template_width + dp.template_width // 2
+                y = (dp.bin_max + dp.bin_min) / 2.0
+                ax.text(x, y, f"chan_id {channel_id}", color="w", ha="center", va="center")
+
+            ax.set_xticks([])
+            ax.set_ylabel(f"unit_id {unit_id}")
diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py
index cb19eda93c..68034ee27e 100644
--- a/src/spikeinterface/widgets/widget_list.py
+++ b/src/spikeinterface/widgets/widget_list.py
@@ -7,7 +7,7 @@
 # waveform
 from .unit_waveforms import UnitWaveformsWidget
 from .unit_templates import UnitTemplatesWidget
-# from .unit_waveforms_density_map import UnitWaveformDensityMapWidget
+from .unit_waveforms_density_map import UnitWaveformDensityMapWidget

 # isi/ccg/acg
 from .autocorrelograms import AutoCorrelogramsWidget
@@ -71,7 +71,7 @@
     UnitLocationsWidget,
     UnitTemplatesWidget,
     UnitWaveformsWidget,
-    # UnitWaveformDensityMapWidget,
+    UnitWaveformDensityMapWidget,
     # UnitDepthsWidget,
     # summary
@@ -141,4 +141,5 @@
 plot_template_similarity = TemplateSimilarityWidget
 plot_unit_templates = UnitTemplatesWidget
 plot_unit_waveforms = UnitWaveformsWidget
+plot_unit_waveforms_density_map = UnitWaveformDensityMapWidget


From d9307ab24a96dad6dbfd2a72b6f68615a79a1d15 Mon Sep 17 00:00:00 2001
From: Samuel Garcia
Date: Wed, 19 Jul 2023 22:32:29 +0200
Subject: [PATCH 112/166] refactor widget : UnitDepthsWidget

---
 src/spikeinterface/widgets/unit_depths.py | 23 +++++++++++++++++++++--
 src/spikeinterface/widgets/widget_list.py |  5 +++--
 2 files changed, 24 insertions(+), 4 deletions(-)

diff --git a/src/spikeinterface/widgets/unit_depths.py b/src/spikeinterface/widgets/unit_depths.py
index 5ceee0c133..9b710815e4 100644
--- a/src/spikeinterface/widgets/unit_depths.py
+++ b/src/spikeinterface/widgets/unit_depths.py
@@ -1,7 +1,7 @@
 import numpy as np
 from warnings import warn

-from .base import BaseWidget
+from .base import BaseWidget, to_attr
 from .utils import get_unit_colors

@@ -24,7 +24,7 @@ class UnitDepthsWidget(BaseWidget):
         Sign of peak for amplitudes, default 'neg'
     """

-    possible_backends = {}
+    # possible_backends = {}

     def __init__(
         self, waveform_extractor, unit_colors=None, depth_axis=1, peak_sign="neg", backend=None, **backend_kwargs
@@ -56,3 +56,22 @@ def __init__(
         )

         BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs)
+
+    def plot_matplotlib(self, data_plot, **backend_kwargs):
+        import matplotlib.pyplot as plt
+        from .matplotlib_utils import make_mpl_figure
+
+        dp = to_attr(data_plot)
+        # backend_kwargs = self.update_backend_kwargs(**backend_kwargs)
+        # self.make_mpl_figure(**backend_kwargs)
+        self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs)
+
+        ax = self.ax
+        size = dp.num_spikes / max(dp.num_spikes) * 120
+        ax.scatter(dp.unit_amplitudes, dp.unit_depths, color=dp.colors, s=size)
+
+        ax.set_aspect(3)
+        ax.set_xlabel("amplitude")
+        ax.set_ylabel("depth [um]")
+        ax.set_xlim(0, max(dp.unit_amplitudes) * 1.2)
+
diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py
index 68034ee27e..4ded22305e 100644
--- a/src/spikeinterface/widgets/widget_list.py
+++ b/src/spikeinterface/widgets/widget_list.py
@@ -49,7 +49,7 @@
 from .template_similarity import TemplateSimilarityWidget


-# from .unit_depths import UnitDepthsWidget
+from .unit_depths import UnitDepthsWidget

 # summary
 # from .unit_summary import UnitSummaryWidget
@@ -72,7 +72,7 @@
     UnitTemplatesWidget,
     UnitWaveformsWidget,
     UnitWaveformDensityMapWidget,
-    # UnitDepthsWidget,
+    UnitDepthsWidget,
     # summary
     # UnitSummaryWidget,
     # SortingSummaryWidget,
@@ -142,4 +142,5 @@
 plot_unit_templates = UnitTemplatesWidget
 plot_unit_waveforms = UnitWaveformsWidget
 plot_unit_waveforms_density_map = UnitWaveformDensityMapWidget
+plot_unit_depths = UnitDepthsWidget


From 8da772269f8ff85244431f7ca16d41172eec27f7 Mon Sep 17 00:00:00 2001
From: Samuel Garcia
Date: Wed, 19 Jul 2023 22:55:33 +0200
Subject: [PATCH 113/166] refactor widget : UnitSummaryWidget

---
 src/spikeinterface/widgets/unit_summary.py | 189 ++++++++++++++++-----
 src/spikeinterface/widgets/widget_list.py  |   5 +-
 2 files changed, 150 insertions(+), 44 deletions(-)

diff --git a/src/spikeinterface/widgets/unit_summary.py b/src/spikeinterface/widgets/unit_summary.py
index 8e1ffe2637..68fa8b77d2 100644
--- a/src/spikeinterface/widgets/unit_summary.py
+++ b/src/spikeinterface/widgets/unit_summary.py
@@ -1,7 +1,7 @@
 import numpy as np
 from typing import Union

-from .base import BaseWidget
+from .base import BaseWidget, to_attr
 from .utils import get_unit_colors

@@ -31,7 +31,7 @@ class UnitSummaryWidget(BaseWidget):
         If WaveformExtractor is already sparse, the argument is ignored
     """

-    possible_backends = {}
+    # possible_backends = {}

     def __init__(
         self,
@@ -48,55 +48,160 @@ def __init__(
         if unit_colors is None:
             unit_colors = get_unit_colors(we.sorting)

-        if we.is_extension("unit_locations"):
-            plot_data_unit_locations = UnitLocationsWidget(
-                we, unit_ids=[unit_id], unit_colors=unit_colors, plot_legend=False
-            ).plot_data
-            unit_locations = waveform_extractor.load_extension("unit_locations").get_data(outputs="by_unit")
-            unit_location = unit_locations[unit_id]
-        else:
-            plot_data_unit_locations = None
-            unit_location = None
+        # if we.is_extension("unit_locations"):
+        #     plot_data_unit_locations = UnitLocationsWidget(
+        #         we, unit_ids=[unit_id], unit_colors=unit_colors, plot_legend=False
+        #     ).plot_data
+        #     unit_locations = waveform_extractor.load_extension("unit_locations").get_data(outputs="by_unit")
+        #     unit_location = unit_locations[unit_id]
+        # else:
+        #     plot_data_unit_locations = None
+        #     unit_location = None
+
+        # plot_data_waveforms = UnitWaveformsWidget(
+        #     we,
+        #     unit_ids=[unit_id],
+        #     unit_colors=unit_colors,
+        #     plot_templates=True,
+        #     same_axis=True,
+        #     plot_legend=False,
+        #     sparsity=sparsity,
+        # ).plot_data
+
+        # plot_data_waveform_density = UnitWaveformDensityMapWidget(
+        #     we, unit_ids=[unit_id], unit_colors=unit_colors, use_max_channel=True, plot_templates=True, same_axis=False
+        # ).plot_data
+
+        # if we.is_extension("correlograms"):
+        #     plot_data_acc = AutoCorrelogramsWidget(
+        #         we,
+        #         unit_ids=[unit_id],
+        #         unit_colors=unit_colors,
+        #     ).plot_data
+        # else:
+        #     plot_data_acc = None
+
+        # use other widget to plot data
+        # if we.is_extension("spike_amplitudes"):
+        #     plot_data_amplitudes = AmplitudesWidget(
+        #         we, unit_ids=[unit_id], unit_colors=unit_colors, plot_legend=False, plot_histograms=True
+        #     ).plot_data
+        # else:
+        #     plot_data_amplitudes = None

-        plot_data_waveforms = UnitWaveformsWidget(
-            we,
-            unit_ids=[unit_id],
+        plot_data = dict(
+            we=we,
+            unit_id=unit_id,
             unit_colors=unit_colors,
-            plot_templates=True,
-            same_axis=True,
-            plot_legend=False,
             sparsity=sparsity,
-        ).plot_data
+            # unit_location=unit_location,
+            # plot_data_unit_locations=plot_data_unit_locations,
+            # plot_data_waveforms=plot_data_waveforms,
+            # plot_data_waveform_density=plot_data_waveform_density,
+            # plot_data_acc=plot_data_acc,
+            # plot_data_amplitudes=plot_data_amplitudes,
+        )
+
+        BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs)

-        plot_data_waveform_density = UnitWaveformDensityMapWidget(
-            we, unit_ids=[unit_id], unit_colors=unit_colors, use_max_channel=True, plot_templates=True, same_axis=False
-        ).plot_data
+    def plot_matplotlib(self, data_plot, **backend_kwargs):
+        import matplotlib.pyplot as plt
+        from .matplotlib_utils import make_mpl_figure
+
+        dp = to_attr(data_plot)
+
+        unit_id = dp.unit_id
+        we = dp.we
+        unit_colors = dp.unit_colors
+        sparsity = dp.sparsity
+
+        # force the figure without axes
+        if "figsize" not in backend_kwargs:
+            backend_kwargs["figsize"] = (18, 7)
+        # backend_kwargs = self.update_backend_kwargs(**backend_kwargs)
+        backend_kwargs["num_axes"] = 0
+        backend_kwargs["ax"] = None
+        backend_kwargs["axes"] = None
+
+        # self.make_mpl_figure(**backend_kwargs)
+        self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs)
+
+        # and use custom grid spec
+        fig = self.figure
+        nrows = 2
+        ncols = 3
+        # if dp.plot_data_acc is not None or dp.plot_data_amplitudes is not None:
+        if we.is_extension("correlograms") or we.is_extension("spike_amplitudes"):
+            ncols += 1
+        # if dp.plot_data_amplitudes is not None :
+        if we.is_extension("spike_amplitudes"):
+            nrows += 1
+        gs = fig.add_gridspec(nrows, ncols)

+        # if dp.plot_data_unit_locations is not None:
+        if we.is_extension("unit_locations"):
+            ax1 = fig.add_subplot(gs[:2, 0])
+            # UnitLocationsPlotter().do_plot(dp.plot_data_unit_locations, ax=ax1)
+            w = UnitLocationsWidget(
+                we, unit_ids=[unit_id], unit_colors=unit_colors, plot_legend=False,
+                backend='matplotlib', ax=ax1)
+
+            unit_locations = we.load_extension("unit_locations").get_data(outputs="by_unit")
+            unit_location = unit_locations[unit_id]
+            # x, y = dp.unit_location[0], dp.unit_location[1]
+            x, y = unit_location[0], unit_location[1]
+            ax1.set_xlim(x - 80, x + 80)
+            ax1.set_ylim(y - 250, y + 250)
+            ax1.set_xticks([])
+            ax1.set_xlabel(None)
+            ax1.set_ylabel(None)
+
+        ax2 = fig.add_subplot(gs[:2, 1])
+        # UnitWaveformPlotter().do_plot(dp.plot_data_waveforms, ax=ax2)
+        w = UnitWaveformsWidget(
+            we,
+            unit_ids=[unit_id],
+            unit_colors=unit_colors,
+            plot_templates=True,
+            same_axis=True,
+            plot_legend=False,
+            sparsity=sparsity,
+            backend='matplotlib', ax=ax2)
+
+        ax2.set_title(None)
+
+        ax3 = fig.add_subplot(gs[:2, 2])
+        # UnitWaveformDensityMapPlotter().do_plot(dp.plot_data_waveform_density, ax=ax3)
+        UnitWaveformDensityMapWidget(
+            we, unit_ids=[unit_id], unit_colors=unit_colors, use_max_channel=True, same_axis=False,
+            backend='matplotlib', ax=ax3)
+        ax3.set_ylabel(None)
+
+        # if dp.plot_data_acc is not None:
         if we.is_extension("correlograms"):
-            plot_data_acc = AutoCorrelogramsWidget(
+            ax4 = fig.add_subplot(gs[:2, 3])
+            # AutoCorrelogramsPlotter().do_plot(dp.plot_data_acc, ax=ax4)
+            AutoCorrelogramsWidget(
                 we,
                 unit_ids=[unit_id],
                 unit_colors=unit_colors,
-            ).plot_data
-        else:
-            plot_data_acc = None
+                backend='matplotlib', ax=ax4,
+            )

-        # use other widget to plot data
-        if we.is_extension("spike_amplitudes"):
-            plot_data_amplitudes = AmplitudesWidget(
-                we, unit_ids=[unit_id], unit_colors=unit_colors, plot_legend=False, plot_histograms=True
-            ).plot_data
-        else:
-            plot_data_amplitudes = None
+            ax4.set_title(None)
+            ax4.set_yticks([])

-        plot_data = dict(
-            unit_id=unit_id,
-            unit_location=unit_location,
-            plot_data_unit_locations=plot_data_unit_locations,
-            plot_data_waveforms=plot_data_waveforms,
-            plot_data_waveform_density=plot_data_waveform_density,
-            plot_data_acc=plot_data_acc,
-            plot_data_amplitudes=plot_data_amplitudes,
-        )
+        # if dp.plot_data_amplitudes is not None:
+        if we.is_extension("spike_amplitudes"):
+            ax5 = fig.add_subplot(gs[2, :3])
+            ax6 = fig.add_subplot(gs[2, 3])
+            axes = np.array([ax5, ax6])
+            # AmplitudesPlotter().do_plot(dp.plot_data_amplitudes, axes=axes)
+            AmplitudesWidget(
+                we, unit_ids=[unit_id], unit_colors=unit_colors, plot_legend=False, plot_histograms=True,
+                backend='matplotlib', axes=axes)

-        BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs)
+        fig.suptitle(f"unit_id: {dp.unit_id}")
diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py
index 4ded22305e..5820477dc8 100644
--- a/src/spikeinterface/widgets/widget_list.py
+++ b/src/spikeinterface/widgets/widget_list.py
@@ -52,7 +52,7 @@
 from .unit_depths import UnitDepthsWidget

 # summary
-# from .unit_summary import UnitSummaryWidget
+from .unit_summary import UnitSummaryWidget
 # from .sorting_summary import SortingSummaryWidget

@@ -74,7 +74,7 @@
     UnitWaveformDensityMapWidget,
     UnitDepthsWidget,
     # summary
-    # UnitSummaryWidget,
+    UnitSummaryWidget,
     # SortingSummaryWidget,
 ]

@@ -143,4 +143,5 @@
 plot_unit_waveforms = UnitWaveformsWidget
 plot_unit_waveforms_density_map = UnitWaveformDensityMapWidget
 plot_unit_depths = UnitDepthsWidget
+plot_unit_summary = UnitSummaryWidget

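The summary widget now drives its sub-widgets directly with ``backend='matplotlib'`` and an explicit ``ax``/``axes``, instead of passing pre-built ``plot_data`` between plotters. A minimal sketch of the resulting API, assuming a ``WaveformExtractor`` ``we`` with the ``unit_locations``, ``correlograms`` and ``spike_amplitudes`` extensions computed:

.. code-block:: python

    from spikeinterface.widgets import plot_unit_summary

    # one composed figure per unit; sub-widgets draw into a shared gridspec
    w = plot_unit_summary(we, unit_id=we.unit_ids[0], backend="matplotlib")
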
From fa49471061712fee17d9475a1947d3f2d3e6d607 Mon Sep 17 00:00:00 2001
From: Samuel Garcia
Date: Wed, 19 Jul 2023 23:17:43 +0200
Subject: [PATCH 114/166] refactor widget : SortingSummaryWidget

---
 src/spikeinterface/widgets/sorting_summary.py | 135 +++++++++++++++---
 src/spikeinterface/widgets/widget_list.py     |   5 +-
 2 files changed, 122 insertions(+), 18 deletions(-)

diff --git a/src/spikeinterface/widgets/sorting_summary.py b/src/spikeinterface/widgets/sorting_summary.py
index 8f50eb1dde..bdf692888f 100644
--- a/src/spikeinterface/widgets/sorting_summary.py
+++ b/src/spikeinterface/widgets/sorting_summary.py
@@ -1,6 +1,6 @@
 import numpy as np

-from .base import BaseWidget, define_widget_function_from_class
+from .base import BaseWidget, to_attr

 from .amplitudes import AmplitudesWidget
 from .crosscorrelograms import CrossCorrelogramsWidget
@@ -34,7 +34,7 @@ class SortingSummaryWidget(BaseWidget):
         (sortingview backend)
     """

-    possible_backends = {}
+    # possible_backends = {}

     def __init__(
         self,
@@ -56,27 +56,130 @@ def __init__(
             unit_ids = sorting.get_unit_ids()

         # use other widgets to generate data (except for similarity)
-        template_plot_data = UnitTemplatesWidget(
-            we, unit_ids=unit_ids, sparsity=sparsity, hide_unit_selector=True
-        ).plot_data
-        ccg_plot_data = CrossCorrelogramsWidget(we, unit_ids=unit_ids, hide_unit_selector=True).plot_data
-        amps_plot_data = AmplitudesWidget(
-            we, unit_ids=unit_ids, max_spikes_per_unit=max_amplitudes_per_unit, hide_unit_selector=True
-        ).plot_data
-        locs_plot_data = UnitLocationsWidget(we, unit_ids=unit_ids, hide_unit_selector=True).plot_data
-        sim_plot_data = TemplateSimilarityWidget(we, unit_ids=unit_ids).plot_data
+        # template_plot_data = UnitTemplatesWidget(
+        #     we, unit_ids=unit_ids, sparsity=sparsity, hide_unit_selector=True
+        # ).plot_data
+        # ccg_plot_data = CrossCorrelogramsWidget(we, unit_ids=unit_ids, hide_unit_selector=True).plot_data
+        # amps_plot_data = AmplitudesWidget(
+        #     we, unit_ids=unit_ids, max_spikes_per_unit=max_amplitudes_per_unit, hide_unit_selector=True
+        # ).plot_data
+        # locs_plot_data = UnitLocationsWidget(we, unit_ids=unit_ids, hide_unit_selector=True).plot_data
+        # sim_plot_data = TemplateSimilarityWidget(we, unit_ids=unit_ids).plot_data

         plot_data = dict(
             waveform_extractor=waveform_extractor,
             unit_ids=unit_ids,
-            templates=template_plot_data,
-            correlograms=ccg_plot_data,
-            amplitudes=amps_plot_data,
-            similarity=sim_plot_data,
-            unit_locations=locs_plot_data,
+            sparsity=sparsity,
+            # templates=template_plot_data,
+            # correlograms=ccg_plot_data,
+            # amplitudes=amps_plot_data,
+            # similarity=sim_plot_data,
+            # unit_locations=locs_plot_data,
             unit_table_properties=unit_table_properties,
             curation=curation,
             label_choices=label_choices,
+
+            max_amplitudes_per_unit=max_amplitudes_per_unit,
         )

         BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs)
+
+    def plot_sortingview(self, data_plot, **backend_kwargs):
+        import sortingview.views as vv
+        from .sortingview_utils import generate_unit_table_view, make_serializable, handle_display_and_url
+
+        dp = to_attr(data_plot)
+        we = dp.waveform_extractor
+        unit_ids = dp.unit_ids
+        sparsity = dp.sparsity
+
+        # unit_ids = self.make_serializable(dp.unit_ids)
+        unit_ids = make_serializable(dp.unit_ids)
+
+        # backend_kwargs = self.update_backend_kwargs(**backend_kwargs)
+
+        # amplitudes_plotter = AmplitudesPlotter()
+        # v_spike_amplitudes = amplitudes_plotter.do_plot(
+        #     dp.amplitudes, generate_url=False, display=False, backend="sortingview"
+        # )
+        # template_plotter = UnitTemplatesPlotter()
+        # v_average_waveforms = template_plotter.do_plot(
+        #     dp.templates, generate_url=False, display=False, backend="sortingview"
+        # )
+        # xcorrelograms_plotter = CrossCorrelogramsPlotter()
+        # v_cross_correlograms = xcorrelograms_plotter.do_plot(
+        #     dp.correlograms, generate_url=False, display=False, backend="sortingview"
+        # )
+        # unitlocation_plotter = UnitLocationsPlotter()
+        # v_unit_locations = unitlocation_plotter.do_plot(
+        #     dp.unit_locations, generate_url=False, display=False, backend="sortingview"
+        # )
+
+        v_spike_amplitudes = AmplitudesWidget(
+            we, unit_ids=unit_ids, max_spikes_per_unit=dp.max_amplitudes_per_unit, hide_unit_selector=True,
+            generate_url=False, display=False, backend="sortingview"
+        ).view
+        v_average_waveforms = UnitTemplatesWidget(
+            we, unit_ids=unit_ids, sparsity=sparsity, hide_unit_selector=True,
+            generate_url=False, display=False, backend="sortingview"
+        ).view
+        v_cross_correlograms = CrossCorrelogramsWidget(we, unit_ids=unit_ids, hide_unit_selector=True,
+            generate_url=False, display=False, backend="sortingview").view
+
+        v_unit_locations = UnitLocationsWidget(we, unit_ids=unit_ids, hide_unit_selector=True,
+            generate_url=False, display=False, backend="sortingview").view
+
+        w = TemplateSimilarityWidget(we, unit_ids=unit_ids, immediate_plot=False,
+            generate_url=False, display=False, backend="sortingview")
+        similarity = w.data_plot["similarity"]
+
+        # similarity
+        similarity_scores = []
+        for i1, u1 in enumerate(unit_ids):
+            for i2, u2 in enumerate(unit_ids):
+                similarity_scores.append(
+                    vv.UnitSimilarityScore(
+                        unit_id1=u1, unit_id2=u2, similarity=similarity[i1, i2].astype("float32")
+                    )
+                )
+
+        # unit ids
+        v_units_table = generate_unit_table_view(
+            dp.waveform_extractor.sorting, dp.unit_table_properties, similarity_scores=similarity_scores
+        )
+
+        if dp.curation:
+            v_curation = vv.SortingCuration2(label_choices=dp.label_choices)
+            v1 = vv.Splitter(direction="vertical", item1=vv.LayoutItem(v_units_table), item2=vv.LayoutItem(v_curation))
+        else:
+            v1 = v_units_table
+        v2 = vv.Splitter(
+            direction="horizontal",
+            item1=vv.LayoutItem(v_unit_locations, stretch=0.2),
+            item2=vv.LayoutItem(
+                vv.Splitter(
+                    direction="horizontal",
+                    item1=vv.LayoutItem(v_average_waveforms),
+                    item2=vv.LayoutItem(
+                        vv.Splitter(
+                            direction="vertical",
+                            item1=vv.LayoutItem(v_spike_amplitudes),
+                            item2=vv.LayoutItem(v_cross_correlograms),
+                        )
+                    ),
+                )
+            ),
+        )
+
+        # assemble layout
+        # v_summary = vv.Splitter(direction="horizontal", item1=vv.LayoutItem(v1), item2=vv.LayoutItem(v2))
+        self.view = vv.Splitter(direction="horizontal", item1=vv.LayoutItem(v1), item2=vv.LayoutItem(v2))
+
+        # self.handle_display_and_url(v_summary, **backend_kwargs)
+        # return v_summary
+
+        self.url = handle_display_and_url(self, self.view, **self.backend_kwargs)
+
diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py
index 5820477dc8..ae0b898035 100644
--- a/src/spikeinterface/widgets/widget_list.py
+++ b/src/spikeinterface/widgets/widget_list.py
@@ -53,7 +53,7 @@
 # summary
 from .unit_summary import UnitSummaryWidget
-# from .sorting_summary import SortingSummaryWidget
+from .sorting_summary import SortingSummaryWidget


 widget_list = [
@@ -75,7 +75,7 @@
     UnitDepthsWidget,
     # summary
     UnitSummaryWidget,
-    # SortingSummaryWidget,
+    SortingSummaryWidget,
 ]

@@ -144,4 +144,5 @@
 plot_unit_waveforms_density_map = UnitWaveformDensityMapWidget
 plot_unit_depths = UnitDepthsWidget
 plot_unit_summary = UnitSummaryWidget
+plot_sorting_summary = SortingSummaryWidget

From 9f9587cf1375155e6ff45b98def20b54ad656b8d Mon Sep 17 00:00:00 2001
From: Samuel Garcia
Date: Thu, 20 Jul 2023 08:43:16 +0200
Subject: [PATCH 115/166] refactor widgets : TimeseriesWidget

---
 src/spikeinterface/widgets/timeseries.py  | 342 +++++++++++++++++++++-
 src/spikeinterface/widgets/widget_list.py |   5 +-
 2 files changed, 342 insertions(+), 5 deletions(-)

diff --git a/src/spikeinterface/widgets/timeseries.py b/src/spikeinterface/widgets/timeseries.py
index 93e0358460..0e82c85b94 100644
--- a/src/spikeinterface/widgets/timeseries.py
+++ b/src/spikeinterface/widgets/timeseries.py
@@ -1,8 +1,10 @@
+import warnings
+
 import numpy as np

 from ..core import BaseRecording, order_channels_by_depth
-from .base import BaseWidget
-from .utils import get_some_colors
+from .base import BaseWidget, to_attr
+from .utils import get_some_colors, array_to_image


 class TimeseriesWidget(BaseWidget):
@@ -56,7 +56,7 @@ class TimeseriesWidget(BaseWidget):
         The output widget
     """

-    possible_backends = {}
+    # possible_backends = {}

     def __init__(
         self,
@@ -213,6 +215,340 @@ def __init__(

         BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs)

+    def plot_matplotlib(self, data_plot, **backend_kwargs):
+        import matplotlib.pyplot as plt
+        from matplotlib.ticker import MaxNLocator
+        from .matplotlib_utils import make_mpl_figure
+
+        dp = to_attr(data_plot)
+        # backend_kwargs = self.update_backend_kwargs(**backend_kwargs)
+
+        # self.make_mpl_figure(**backend_kwargs)
+        self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs)
+
+        ax = self.ax
+        n = len(dp.channel_ids)
+        if dp.channel_locations is not None:
+            y_locs = dp.channel_locations[:, 1]
+        else:
+            y_locs = np.arange(n)
+        min_y = np.min(y_locs)
+        max_y = np.max(y_locs)
+
+        if dp.mode == "line":
+            offset = dp.vspacing * (n - 1)
+
+            for layer_key, traces in zip(dp.layer_keys, dp.list_traces):
+                for i, chan_id in enumerate(dp.channel_ids):
+                    offset = dp.vspacing * i
+                    color = dp.colors[layer_key][chan_id]
+                    ax.plot(dp.times, offset + traces[:, i], color=color)
+                ax.get_lines()[-1].set_label(layer_key)
+
+            if dp.show_channel_ids:
+                ax.set_yticks(np.arange(n) * dp.vspacing)
+                channel_labels = np.array([str(chan_id) for chan_id in dp.channel_ids])
+                ax.set_yticklabels(channel_labels)
+            else:
+                ax.get_yaxis().set_visible(False)
+
+            ax.set_xlim(*dp.time_range)
+            ax.set_ylim(-dp.vspacing, dp.vspacing * n)
+            ax.get_xaxis().set_major_locator(MaxNLocator(prune="both"))
+            ax.set_xlabel("time (s)")
+            if dp.add_legend:
+                ax.legend(loc="upper right")
+
+        elif dp.mode == "map":
+            assert len(dp.list_traces) == 1, 'plot_timeseries with mode="map" does not support multi-recording'
+            assert len(dp.clims) == 1
+            clim = list(dp.clims.values())[0]
+            extent = (dp.time_range[0], dp.time_range[1], min_y, max_y)
+            im = ax.imshow(
+                dp.list_traces[0].T, interpolation="nearest", origin="lower", aspect="auto", extent=extent, cmap=dp.cmap
+            )
+
+            im.set_clim(*clim)
+
+            if dp.with_colorbar:
+                self.figure.colorbar(im, ax=ax)
+
+            if dp.show_channel_ids:
+                ax.set_yticks(np.linspace(min_y, max_y, n) + (max_y - min_y) / n * 0.5)
+                channel_labels = np.array([str(chan_id) for chan_id in dp.channel_ids])
+                ax.set_yticklabels(channel_labels)
+            else:
+                ax.get_yaxis().set_visible(False)
+
+    def plot_ipywidgets(self, data_plot, **backend_kwargs):
+        import matplotlib.pyplot as plt
+        import ipywidgets.widgets as widgets
+        from IPython.display import display
+        from .ipywidgets_utils import check_ipywidget_backend, make_timeseries_controller, make_channel_controller, make_scale_controller
+
+        check_ipywidget_backend()
+
+        self.next_data_plot = data_plot.copy()
+
+        recordings = data_plot["recordings"]
+
+        # first layer
+        rec0 = recordings[data_plot["layer_keys"][0]]
+
+        cm = 1 / 2.54
+
+        # backend_kwargs = self.update_backend_kwargs(**backend_kwargs)
+        width_cm = backend_kwargs["width_cm"]
+        height_cm = backend_kwargs["height_cm"]
+        ratios = [0.1, 0.8, 0.2]
+
+        with plt.ioff():
+            output = widgets.Output()
+            with output:
+                self.figure, self.ax = plt.subplots(figsize=(0.9 * ratios[1] * width_cm * cm, height_cm * cm))
+                plt.show()
+
+        t_start = 0.0
+        t_stop = rec0.get_num_samples(segment_index=0) / rec0.get_sampling_frequency()
+
+        ts_widget, ts_controller = make_timeseries_controller(
+            t_start,
+            t_stop,
+            data_plot["layer_keys"],
+            rec0.get_num_segments(),
+            data_plot["time_range"],
+            data_plot["mode"],
+            False,
+            width_cm,
+        )
+
+        ch_widget, ch_controller = make_channel_controller(rec0, width_cm=ratios[2] * width_cm, height_cm=height_cm)
+
+        scale_widget, scale_controller = make_scale_controller(width_cm=ratios[0] * width_cm, height_cm=height_cm)
+
+        self.controller = ts_controller
+        self.controller.update(ch_controller)
+        self.controller.update(scale_controller)
+
+        # mpl_plotter = MplTimeseriesPlotter()
+
+        # self.updater = PlotUpdater(data_plot, mpl_plotter, ax, self.controller)
+        # for w in self.controller.values():
+        #     if isinstance(w, widgets.Button):
+        #         w.on_click(self.updater)
+        #     else:
+        #         w.observe(self.updater)
+
+        self.recordings = data_plot["recordings"]
+        self.return_scaled = data_plot["return_scaled"]
+        self.list_traces = None
+        self.actual_segment_index = self.controller["segment_index"].value
+
+        self.rec0 = self.recordings[self.data_plot["layer_keys"][0]]
+        self.t_stops = [
+            self.rec0.get_num_samples(segment_index=seg_index) / self.rec0.get_sampling_frequency()
+            for seg_index in range(self.rec0.get_num_segments())
+        ]
+
+        for w in self.controller.values():
+            if isinstance(w, widgets.Button):
+                w.on_click(self._update_ipywidget)
+            else:
+                w.observe(self._update_ipywidget)
+
+        self.widget = widgets.AppLayout(
+            center=self.figure.canvas,
+            footer=ts_widget,
+            left_sidebar=scale_widget,
+            right_sidebar=ch_widget,
+            pane_heights=[0, 6, 1],
+            pane_widths=ratios,
+        )
+
+        # a first update
+        # self.updater(None)
+        self._update_ipywidget(None)
+
+        if backend_kwargs["display"]:
+            # self.check_backend()
+            display(self.widget)
+
+    def _update_ipywidget(self, change):
+        import ipywidgets.widgets as widgets
+
+        # if changing the layer_key, no need to retrieve and process traces
+        retrieve_traces = True
+        scale_up = False
+        scale_down = False
+        if change is not None:
+            for cname, c in self.controller.items():
+                if isinstance(change, dict):
+                    if change["owner"] is c and cname == "layer_key":
+                        retrieve_traces = False
+                elif isinstance(change, widgets.Button):
+                    if change is c and cname == "plus":
+                        scale_up = True
+                    if change is c and cname == "minus":
+                        scale_down = True
+
+        t_start = self.controller["t_start"].value
+        window = self.controller["window"].value
+        layer_key = self.controller["layer_key"].value
+        segment_index = self.controller["segment_index"].value
+        mode = self.controller["mode"].value
+        chan_start, chan_stop = self.controller["channel_inds"].value
+
+        if mode == "line":
+            self.controller["all_layers"].layout.visibility = "visible"
+            all_layers = self.controller["all_layers"].value
+        elif mode == "map":
+            self.controller["all_layers"].layout.visibility = "hidden"
+            all_layers = False
+
+        if all_layers:
+            self.controller["layer_key"].layout.visibility = "hidden"
+        else:
+            self.controller["layer_key"].layout.visibility = "visible"
+
+        if chan_start == chan_stop:
+            chan_stop += 1
+        channel_indices = np.arange(chan_start, chan_stop)
+
+        t_stop = self.t_stops[segment_index]
+        if self.actual_segment_index != segment_index:
+            # change time_slider limits
+            self.controller["t_start"].max = t_stop
+            self.actual_segment_index = segment_index
+
+        # protect limits
+        if t_start >= t_stop - window:
+            t_start = t_stop - window
+
+        time_range = np.array([t_start, t_start + window])
+        data_plot = self.next_data_plot
+
+        if retrieve_traces:
+            all_channel_ids = self.recordings[list(self.recordings.keys())[0]].channel_ids
+            if self.data_plot["order"] is not None:
+                all_channel_ids = all_channel_ids[self.data_plot["order"]]
+            channel_ids = all_channel_ids[channel_indices]
+            if self.data_plot["order_channel_by_depth"]:
+                order, _ = order_channels_by_depth(self.rec0, channel_ids)
+            else:
+                order = None
+            times, list_traces, frame_range, channel_ids = _get_trace_list(
+                self.recordings, channel_ids, time_range, segment_index, order, self.return_scaled
+            )
+            self.list_traces = list_traces
+        else:
+            times = data_plot["times"]
+            list_traces = data_plot["list_traces"]
+            frame_range = data_plot["frame_range"]
+            channel_ids = data_plot["channel_ids"]
+
+        if all_layers:
+            layer_keys = self.data_plot["layer_keys"]
+            recordings = self.recordings
+            list_traces_plot = self.list_traces
+        else:
+            layer_keys = [layer_key]
+            recordings = {layer_key: self.recordings[layer_key]}
+            list_traces_plot = [self.list_traces[list(self.recordings.keys()).index(layer_key)]]
+
+        if scale_up:
+            if mode == "line":
+                data_plot["vspacing"] *= 0.8
+            elif mode == "map":
+                data_plot["clims"] = {
+                    layer: (1.2 * val[0], 1.2 * val[1]) for layer, val in self.data_plot["clims"].items()
+                }
+        if scale_down:
+            if mode == "line":
+                data_plot["vspacing"] *= 1.2
+            elif mode == "map":
+                data_plot["clims"] = {
+                    layer: (0.8 * val[0], 0.8 * val[1]) for layer, val in self.data_plot["clims"].items()
+                }
+
+        self.next_data_plot["vspacing"] = data_plot["vspacing"]
+        self.next_data_plot["clims"] = data_plot["clims"]
+
+        if mode == "line":
+            clims = None
+        elif mode == "map":
+            clims = {layer_key: self.data_plot["clims"][layer_key]}
+
+        # matplotlib next_data_plot dict update at each call
+        data_plot["mode"] = mode
+        data_plot["frame_range"] = frame_range
+        data_plot["time_range"] = time_range
+        data_plot["with_colorbar"] = False
+        data_plot["recordings"] = recordings
+        data_plot["layer_keys"] = layer_keys
+        data_plot["list_traces"] = list_traces_plot
+        data_plot["times"] = times
+        data_plot["clims"] = clims
+        data_plot["channel_ids"] = channel_ids
+
+        backend_kwargs = {}
+        backend_kwargs["ax"] = self.ax
+        # self.mpl_plotter.do_plot(data_plot, **backend_kwargs)
+        self.plot_matplotlib(data_plot, **backend_kwargs)
+
+        fig = self.ax.figure
+        fig.canvas.draw()
+        fig.canvas.flush_events()
+
+    def plot_sortingview(self, data_plot, **backend_kwargs):
+        import sortingview.views as vv
+        from .sortingview_utils import generate_unit_table_view, make_serializable, handle_display_and_url
+
+        try:
+            import pyvips
+        except ImportError:
+            raise ImportError("To use the timeseries in sortingview you need the pyvips package.")
+
+        # backend_kwargs = self.update_backend_kwargs(**backend_kwargs)
+        dp = to_attr(data_plot)
+
+        assert dp.mode == "map", 'sortingview plot_timeseries is only mode="map"'
+
+        if not dp.order_channel_by_depth:
+            warnings.warn(
+                "It is recommended to set 'order_channel_by_depth' to True " "when using the sortingview backend"
+            )
+
+        tiled_layers = []
+        for layer_key, traces in zip(dp.layer_keys, dp.list_traces):
+            img = array_to_image(
+                traces,
+                clim=dp.clims[layer_key],
+                num_timepoints_per_row=dp.num_timepoints_per_row,
+                colormap=dp.cmap,
+                scalebar=True,
+                sampling_frequency=dp.recordings[layer_key].get_sampling_frequency(),
+            )
+
+            tiled_layers.append(vv.TiledImageLayer(layer_key, img))
+
+        # view_ts = vv.TiledImage(tile_size=dp.tile_size, layers=tiled_layers)
+        self.view = vv.TiledImage(tile_size=dp.tile_size, layers=tiled_layers)
+
+        # self.set_view(view_ts)
+
+        # timeseries currently doesn't display on the jupyter backend
+        backend_kwargs["display"] = False
+        # self.handle_display_and_url(view_ts, **backend_kwargs)
+        # return view_ts
+        self.url = handle_display_and_url(self, self.view, **self.backend_kwargs)
+

 def _get_trace_list(recordings, channel_ids, time_range, segment_index, order=None, return_scaled=False):
     # function also used in ipywidgets plotter
diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py
index ae0b898035..11fe0b0e92 100644
--- a/src/spikeinterface/widgets/widget_list.py
+++ b/src/spikeinterface/widgets/widget_list.py
@@ -2,7 +2,7 @@
 from .base import backend_kwargs_desc

 # basics
-# from .timeseries import TimeseriesWidget
+from .timeseries import TimeseriesWidget

 # waveform
 from .unit_waveforms import UnitWaveformsWidget
@@ -67,7 +67,7 @@
     TemplateMetricsWidget,
     MotionWidget,
     TemplateSimilarityWidget,
-    # TimeseriesWidget,
+    TimeseriesWidget,
     UnitLocationsWidget,
     UnitTemplatesWidget,
     UnitWaveformsWidget,
@@ -136,6 +136,7 @@
 plot_crosscorrelograms = CrossCorrelogramsWidget
 plot_spike_locations = SpikeLocationsWidget
 plot_template_metrics = TemplateMetricsWidget
+plot_timeseries = TimeseriesWidget
 plot_quality_metrics = QualityMetricsWidget
 plot_motion = MotionWidget
 plot_template_similarity = TemplateSimilarityWidget

From 97410f9dda2133b97609e858684974d39360fa76 Mon Sep 17 00:00:00 2001
From: Samuel Garcia
Date: Thu, 20 Jul 2023 09:22:45 +0200
Subject: [PATCH 116/166] refactor widget : SpikesOnTracesWidget

---
 .../widgets/spikes_on_traces.py           | 280 ++++++++++++++++++++--
 src/spikeinterface/widgets/widget_list.py |   5 +-
 2 files changed, 260 insertions(+), 25 deletions(-)

diff --git a/src/spikeinterface/widgets/spikes_on_traces.py b/src/spikeinterface/widgets/spikes_on_traces.py
index b50896df4d..9deb346387 100644
--- a/src/spikeinterface/widgets/spikes_on_traces.py
+++ b/src/spikeinterface/widgets/spikes_on_traces.py
@@ -1,6 +1,6 @@
 import numpy as np

-from .base import BaseWidget
+from .base import BaseWidget, to_attr
 from .utils import get_unit_colors
 from .timeseries import TimeseriesWidget
 from ..core import ChannelSparsity
@@ -60,7 +60,7 @@ class SpikesOnTracesWidget(BaseWidget):
         For 'map' mode and sortingview backend, seconds to render in each row, default 0.2
     """

-    possible_backends = {}
+    # possible_backends = {}

     def __init__(
         self,
@@ -86,28 +86,28 @@ def __init__(
         **backend_kwargs,
     ):
         we = waveform_extractor
-        recording: BaseRecording = we.recording
+        # recording: BaseRecording = we.recording
         sorting: BaseSorting = we.sorting

-        ts_widget = TimeseriesWidget(
-            recording,
-            segment_index,
-            channel_ids,
-            order_channel_by_depth,
-            time_range,
-            mode,
-            return_scaled,
-            cmap,
-            show_channel_ids,
-            color_groups,
-            color,
-            clim,
-            tile_size,
-            seconds_per_row,
-            with_colorbar,
-            backend,
-            **backend_kwargs,
-        )
+        # ts_widget = TimeseriesWidget(
+        #     recording,
+        #     segment_index,
+        #     channel_ids,
+        #     order_channel_by_depth,
+        #     time_range,
+        #     mode,
+        #     return_scaled,
+        #     cmap,
+        #     show_channel_ids,
+        #     color_groups,
+        #     color,
+        #     clim,
+        #     tile_size,
+        #     seconds_per_row,
+        #     with_colorbar,
+        #     backend,
+        #     **backend_kwargs,
+        # )

         if unit_ids is None:
             unit_ids = sorting.get_unit_ids()
@@ -133,9 +133,26 @@ def __init__(
         # get templates
         unit_locations = compute_unit_locations(we, outputs="by_unit")

+        options = dict(
+            segment_index=segment_index,
+            channel_ids=channel_ids,
+            order_channel_by_depth=order_channel_by_depth,
+            time_range=time_range,
+            mode=mode,
+            return_scaled=return_scaled,
+            cmap=cmap,
+            show_channel_ids=show_channel_ids,
+            color_groups=color_groups,
+            color=color,
+            clim=clim,
+            tile_size=tile_size,
+            with_colorbar=with_colorbar,
+        )
+
         plot_data = dict(
-            timeseries=ts_widget.plot_data,
+            # timeseries=ts_widget.plot_data,
             waveform_extractor=waveform_extractor,
+            options=options,
             unit_ids=unit_ids,
             sparsity=sparsity,
             unit_colors=unit_colors,
@@ -143,3 +160,220 @@ def __init__(
         )

         BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs)
+
+    def plot_matplotlib(self, data_plot, **backend_kwargs):
+        import matplotlib.pyplot as plt
+        from .matplotlib_utils import make_mpl_figure
+
+        from matplotlib.patches import Ellipse
+        from matplotlib.lines import Line2D
+
+        dp = to_attr(data_plot)
+        we = dp.waveform_extractor
+        recording = we.recording
+        sorting = we.sorting
+
+        # first plot time series
+        # tsplotter = TimeseriesPlotter()
+        # data_plot["timeseries"]["add_legend"] = False
+        # tsplotter.do_plot(dp.timeseries, **backend_kwargs)
+        # self.ax = tsplotter.ax
+        # self.axes = tsplotter.axes
+        # self.figure = tsplotter.figure
+
+        # first plot time series
+        ts_widget = TimeseriesWidget(recording, **dp.options, backend="matplotlib", **backend_kwargs)
+        self.ax = ts_widget.ax
+        self.axes = ts_widget.axes
+        self.figure = ts_widget.figure
+
+        ax = self.ax
+
+        # we = dp.waveform_extractor
+        # sorting = dp.waveform_extractor.sorting
+        # frame_range = dp.timeseries["frame_range"]
+        # segment_index = dp.timeseries["segment_index"]
+        # min_y = np.min(dp.timeseries["channel_locations"][:, 1])
+        # max_y = np.max(dp.timeseries["channel_locations"][:, 1])
+
+        frame_range = ts_widget.data_plot["frame_range"]
+        segment_index = ts_widget.data_plot["segment_index"]
+        min_y = np.min(ts_widget.data_plot["channel_locations"][:, 1])
+        max_y = np.max(ts_widget.data_plot["channel_locations"][:, 1])
+
+        # n = len(dp.timeseries["channel_ids"])
+        # order = dp.timeseries["order"]
+        n = len(ts_widget.data_plot["channel_ids"])
+        order = ts_widget.data_plot["order"]
+
+        if order is None:
+            order = np.arange(n)
+
+        if ax.get_legend() is not None:
+            ax.get_legend().remove()
+
+        # loop through units and plot a scatter of spikes at estimated location
+        handles = []
+        labels = []
+
+        for unit in dp.unit_ids:
+            spike_frames = sorting.get_unit_spike_train(unit, segment_index=segment_index)
+            spike_start, spike_end = np.searchsorted(spike_frames, frame_range)
+
+            chan_ids = dp.sparsity.unit_id_to_channel_ids[unit]
+
+            spike_frames_to_plot = spike_frames[spike_start:spike_end]
+
+            # if dp.timeseries["mode"] == "map":
+            if dp.options["mode"] == "map":
+                spike_times_to_plot = sorting.get_unit_spike_train(
+                    unit, segment_index=segment_index, return_times=True
+                )[spike_start:spike_end]
+                unit_y_loc = min_y + max_y - dp.unit_locations[unit][1]
+                # markers = np.ones_like(spike_frames_to_plot) * (min_y + max_y - dp.unit_locations[unit][1])
+                width = 2 * 1e-3
+                ellipse_kwargs = dict(width=width, height=10, fc="none", ec=dp.unit_colors[unit], lw=2)
+                patches = [Ellipse((s, unit_y_loc), **ellipse_kwargs) for s in spike_times_to_plot]
+                for p in patches:
+                    ax.add_patch(p)
+                handles.append(
+                    Line2D(
+                        [0],
+                        [0],
+                        ls="",
+                        marker="o",
+                        markersize=5,
+                        markeredgewidth=2,
+                        markeredgecolor=dp.unit_colors[unit],
+                        markerfacecolor="none",
+                    )
+                )
+                labels.append(unit)
+            else:
+                # construct waveforms
+                label_set = False
+                if len(spike_frames_to_plot) > 0:
+                    # vspacing = dp.timeseries["vspacing"]
+                    # traces = dp.timeseries["list_traces"][0]
+                    vspacing = ts_widget.data_plot["vspacing"]
+                    traces = ts_widget.data_plot["list_traces"][0]
+
+                    waveform_idxs = spike_frames_to_plot[:, None] + np.arange(-we.nbefore, we.nafter) - frame_range[0]
+                    # waveform_idxs = np.clip(waveform_idxs, 0, len(dp.timeseries["times"]) - 1)
+                    waveform_idxs = np.clip(waveform_idxs, 0, len(ts_widget.data_plot["times"]) - 1)
+
+                    # times = dp.timeseries["times"][waveform_idxs]
+                    times = ts_widget.data_plot["times"][waveform_idxs]
+
+                    # discontinuity
+                    times[:, -1] = np.nan
+                    times_r = times.reshape(times.shape[0] * times.shape[1])
+                    waveforms = traces[waveform_idxs]  # [:, :, order]
+                    waveforms_r = waveforms.reshape((waveforms.shape[0] * waveforms.shape[1], waveforms.shape[2]))
+
+                    # for i, chan_id in enumerate(dp.timeseries["channel_ids"]):
+                    for i, chan_id in enumerate(ts_widget.data_plot["channel_ids"]):
+                        offset = vspacing * i
+                        if chan_id in chan_ids:
+                            l = ax.plot(times_r, offset + waveforms_r[:, i], color=dp.unit_colors[unit])
+                            if not label_set:
+                                handles.append(l[0])
+                                labels.append(unit)
+                                label_set = True
+        ax.legend(handles, labels)
+
+    def plot_ipywidgets(self, data_plot, **backend_kwargs):
+        import matplotlib.pyplot as plt
+        import ipywidgets.widgets as widgets
+        from IPython.display import display
+        from .ipywidgets_utils import check_ipywidget_backend, make_unit_controller
+
+        check_ipywidget_backend()
+
+        self.next_data_plot = data_plot.copy()
+
+        dp = to_attr(data_plot)
+        we = dp.waveform_extractor
+
+        ratios = [0.2, 0.8]
+        # backend_kwargs = self.update_backend_kwargs(**backend_kwargs)
+
+        backend_kwargs_ts = backend_kwargs.copy()
+        backend_kwargs_ts["width_cm"] = ratios[1] * backend_kwargs_ts["width_cm"]
+        backend_kwargs_ts["display"] = False
+        height_cm = backend_kwargs["height_cm"]
+        width_cm = backend_kwargs["width_cm"]
+
+        # plot timeseries
+        # tsplotter = TimeseriesPlotter()
+        # data_plot["timeseries"]["add_legend"] = False
+        # tsplotter.do_plot(data_plot["timeseries"], **backend_kwargs_ts)
+
+        # ts_w = tsplotter.widget
+        # ts_updater = tsplotter.updater
+
+        ts_widget = TimeseriesWidget(we.recording, **dp.options, backend="ipywidgets", **backend_kwargs_ts)
+        self.ax = ts_widget.ax
+        self.axes = ts_widget.axes
+        self.figure = ts_widget.figure
+
+        # we = data_plot["waveform_extractor"]
+
+        unit_widget, unit_controller = make_unit_controller(
+            data_plot["unit_ids"], we.unit_ids, ratios[0] * width_cm, height_cm
+        )
+
+        self.controller = dict()
+        # self.controller = ts_updater.controller
+        self.controller.update(ts_widget.controller)
+        self.controller.update(unit_controller)
+
+        # mpl_plotter = MplSpikesOnTracesPlotter()
+
+        # self.updater = PlotUpdater(data_plot, mpl_plotter, ts_updater, self.controller)
+        # for w in self.controller.values():
+        #     w.observe(self.updater)
+
+        for w in self.controller.values():
+            w.observe(self._update_ipywidget)
+
+        self.widget = widgets.AppLayout(center=ts_widget.widget, left_sidebar=unit_widget, pane_widths=ratios + [0])
+
+        # a first update
+        # self.updater(None)
+        self._update_ipywidget(None)
+
+        if backend_kwargs["display"]:
+            # self.check_backend()
+            display(self.widget)
+
+    def _update_ipywidget(self, change):
+        self.ax.clear()
+
+        unit_ids = self.controller["unit_ids"].value
+
+        # update ts
+        # self.ts_updater.__call__(change)
+
+        # update data plot
+        # data_plot = self.data_plot.copy()
+        data_plot = self.next_data_plot
+        # data_plot["timeseries"] = self.ts_updater.next_data_plot
+        data_plot["unit_ids"] = unit_ids
+
+        backend_kwargs = {}
+        backend_kwargs["ax"] = self.ax
+
+        # self.mpl_plotter.do_plot(data_plot, **backend_kwargs)
+        self.plot_matplotlib(data_plot, **backend_kwargs)
+
+        self.figure.canvas.draw()
+        self.figure.canvas.flush_events()
diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py
index 11fe0b0e92..db73dbc5ec 100644
--- a/src/spikeinterface/widgets/widget_list.py
+++ b/src/spikeinterface/widgets/widget_list.py
@@ -18,7 +18,7 @@

 # drift/motion

 # spikes-traces
-# from .spikes_on_traces import SpikesOnTracesWidget
+from .spikes_on_traces import SpikesOnTracesWidget


 # PC related
@@ -63,7 +63,7 @@
     CrossCorrelogramsWidget,
     QualityMetricsWidget,
     SpikeLocationsWidget,
-    # SpikesOnTracesWidget,
+    SpikesOnTracesWidget,
     TemplateMetricsWidget,
     MotionWidget,
     TemplateSimilarityWidget,
@@ -135,6 +135,7 @@
 plot_autocorrelograms = AutoCorrelogramsWidget
 plot_crosscorrelograms = CrossCorrelogramsWidget
 plot_spike_locations = SpikeLocationsWidget
+plot_spikes_on_traces = SpikesOnTracesWidget
 plot_template_metrics = TemplateMetricsWidget
 plot_timeseries = TimeseriesWidget
 plot_quality_metrics = QualityMetricsWidget


From d159145376368e8a48bc5340fd868d17a95eff3c Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Thu, 20 Jul 2023 07:27:31 +0000
Subject: [PATCH 117/166] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
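The series lands on a simple contract in ``base.py``: a backend is available for a widget exactly when the class defines a ``plot_<backend>`` method, as ``get_possible_backends`` in the diff below shows. A minimal sketch of that discovery rule (``MyWidget`` is a hypothetical subclass, not part of the codebase):

.. code-block:: python

    from spikeinterface.widgets.base import BaseWidget

    class MyWidget(BaseWidget):
        def plot_matplotlib(self, data_plot, **backend_kwargs):
            ...

    # only backends with a matching plot_* method are reported
    assert MyWidget.get_possible_backends() == ["matplotlib"]
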
 .../widgets/all_amplitudes_distributions.py   |  3 +-
 src/spikeinterface/widgets/amplitudes.py      | 14 ++---
 .../widgets/autocorrelograms.py               |  4 +-
 src/spikeinterface/widgets/base.py            | 32 +++++------
 .../widgets/crosscorrelograms.py              |  4 +-
 .../widgets/ipywidgets_utils.py               |  3 +-
 .../widgets/matplotlib_utils.py               |  4 +-
 src/spikeinterface/widgets/metrics.py         | 10 ++--
 src/spikeinterface/widgets/motion.py          |  4 +-
 src/spikeinterface/widgets/sorting_summary.py | 47 +++++++++------
 .../widgets/sortingview_utils.py              | 11 ++--
 src/spikeinterface/widgets/spike_locations.py | 11 ++--
 .../widgets/spikes_on_traces.py               | 16 ++----
 .../widgets/template_similarity.py            |  4 +-
 .../widgets/tests/test_widgets.py             |  2 +-
 src/spikeinterface/widgets/timeseries.py      | 13 +++--
 src/spikeinterface/widgets/unit_depths.py     |  1 -
 src/spikeinterface/widgets/unit_locations.py  | 18 ++-----
 src/spikeinterface/widgets/unit_summary.py    | 54 +++++++++++--------
 src/spikeinterface/widgets/unit_templates.py  |  5 +-
 src/spikeinterface/widgets/unit_waveforms.py  |  9 ++--
 src/spikeinterface/widgets/widget_list.py     |  9 ++--
 22 files changed, 128 insertions(+), 150 deletions(-)

diff --git a/src/spikeinterface/widgets/all_amplitudes_distributions.py b/src/spikeinterface/widgets/all_amplitudes_distributions.py
index 18585a4f96..d3cca278c9 100644
--- a/src/spikeinterface/widgets/all_amplitudes_distributions.py
+++ b/src/spikeinterface/widgets/all_amplitudes_distributions.py
@@ -55,7 +55,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
         from matplotlib.patches import Ellipse
         from matplotlib.lines import Line2D

-
         dp = to_attr(data_plot)
         # backend_kwargs = self.update_backend_kwargs(**backend_kwargs)

         # self.make_mpl_figure(**backend_kwargs)
@@ -85,4 +84,4 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
         if np.max(ylims) < 0:
             ax.set_ylim(min(ylims), 0)
         if np.min(ylims) > 0:
-            ax.set_ylim(0, max(ylims))
\ No newline at end of file
+            ax.set_ylim(0, max(ylims))
diff --git a/src/spikeinterface/widgets/amplitudes.py b/src/spikeinterface/widgets/amplitudes.py
index 7c76d26204..a2a3ccff3b 100644
--- a/src/spikeinterface/widgets/amplitudes.py
+++ b/src/spikeinterface/widgets/amplitudes.py
@@ -121,9 +121,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
         from matplotlib.patches import Ellipse
         from matplotlib.lines import Line2D

-
-
-
         dp = to_attr(data_plot)
         # backend_kwargs = self.update_backend_kwargs(**backend_kwargs)

@@ -168,7 +165,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):

         if dp.plot_legend:
             # if self.legend is not None:
-            if hasattr(self, 'legend') and self.legend is not None:
+            if hasattr(self, "legend") and self.legend is not None:
                 self.legend.remove()
             self.legend = self.figure.legend(
                 loc="upper center", bbox_to_anchor=(0.5, 1.0), ncol=5, fancybox=True, shadow=True
@@ -186,7 +183,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):
         import ipywidgets.widgets as widgets
         from IPython.display import display
         from .ipywidgets_utils import check_ipywidget_backend, make_unit_controller
-
+
         check_ipywidget_backend()

         self.next_data_plot = data_plot.copy()
@@ -232,7 +229,10 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):

         self.widget = widgets.AppLayout(
             # center=fig.canvas, left_sidebar=unit_widget, pane_widths=ratios + [0], footer=footer
-            center=self.figure.canvas, left_sidebar=unit_widget, pane_widths=ratios + [0], footer=footer
+            center=self.figure.canvas,
+            left_sidebar=unit_widget,
+            pane_widths=ratios + [0],
+            footer=footer,
         )

         # a first update
@@ -241,7 +241,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):

         if backend_kwargs["display"]:
             # self.check_backend()
-            display(self.widget)
+            display(self.widget)

     def _update_ipywidget(self, change):
         # self.fig.clear()
diff --git a/src/spikeinterface/widgets/autocorrelograms.py b/src/spikeinterface/widgets/autocorrelograms.py
index f07246efa6..e7b5014367 100644
--- a/src/spikeinterface/widgets/autocorrelograms.py
+++ b/src/spikeinterface/widgets/autocorrelograms.py
@@ -41,7 +41,7 @@ def plot_sortingview(self, data_plot, **backend_kwargs):
         # backend_kwargs = self.update_backend_kwargs(**backend_kwargs)
         dp = to_attr(data_plot)

-        # unit_ids = self.make_serializable(dp.unit_ids)
+        # unit_ids = self.make_serializable(dp.unit_ids)
         unit_ids = make_serializable(dp.unit_ids)

         ac_items = []
@@ -63,6 +63,4 @@ def plot_sortingview(self, data_plot, **backend_kwargs):

         self.url = handle_display_and_url(self, self.view, **self.backend_kwargs)

-
-
 AutoCorrelogramsWidget.__doc__ = CrossCorrelogramsWidget.__doc__
diff --git a/src/spikeinterface/widgets/base.py b/src/spikeinterface/widgets/base.py
index 7b0ba0454e..a1cc76eb19 100644
--- a/src/spikeinterface/widgets/base.py
+++ b/src/spikeinterface/widgets/base.py
@@ -19,7 +19,6 @@ def set_default_plotter_backend(backend):
     default_backend_ = backend


-
 backend_kwargs_desc = {
     "matplotlib": {
         "figure": "Matplotlib figure. When None, it is created. Default None",
         "figsize": "Size of matplotlib figure. Default None",
         "figtitle": "The figure title. Default None",
     },
-    'sortingview': {
+    "sortingview": {
         "generate_url": "If True, the figurl URL is generated and printed. Default True",
         "display": "If True and in jupyter notebook/lab, the widget is displayed in the cell. Default True.",
         "figlabel": "The figurl figure label. Default None",
         "height": "The height of the sortingview View in jupyter. Default None",
     },
-    "ipywidgets" : {
-        "width_cm": "Width of the figure in cm (default 10)",
-        "height_cm": "Height of the figure in cm (default 6)",
-        "display": "If True, widgets are immediately displayed",
+    "ipywidgets": {
+        "width_cm": "Width of the figure in cm (default 10)",
+        "height_cm": "Height of the figure in cm (default 6)",
+        "display": "If True, widgets are immediately displayed",
     },
-
 }

 default_backend_kwargs = {
     "matplotlib": {"figure": None, "ax": None, "axes": None, "ncols": 5, "figsize": None, "figtitle": None},
     "sortingview": {"generate_url": True, "display": True, "figlabel": None, "height": None},
-    "ipywidgets" : {"width_cm": 25, "height_cm": 10, "display": True},
+    "ipywidgets": {"width_cm": 25, "height_cm": 10, "display": True},
 }

-
 class BaseWidget:
     # this need to be reset in the subclass
     possible_backends = None

-    def __init__(self, data_plot=None, backend=None, immediate_plot=True, **backend_kwargs, ):
+    def __init__(
+        self,
+        data_plot=None,
+        backend=None,
+        immediate_plot=True,
+        **backend_kwargs,
+    ):
         # every widgets must prepare a dict "plot_data" in the init
         self.data_plot = data_plot
         backend = self.check_backend(backend)
@@ -70,16 +73,16 @@ def __init__(self, data_plot=None, backend=None, immediate_plot=True, **backend_
         )
         backend_kwargs_ = default_backend_kwargs[self.backend].copy()
         backend_kwargs_.update(backend_kwargs)
-
+
         self.backend_kwargs = backend_kwargs_

         if immediate_plot:
-            print('immediate_plot', self.backend, self.backend_kwargs)
+            print("immediate_plot", self.backend, self.backend_kwargs)
             self.do_plot(self.backend, **self.backend_kwargs)

     @classmethod
     def get_possible_backends(cls):
-        return [ k for k in default_backend_kwargs if hasattr(cls, f"plot_{k}") ]
+        return [k for k in default_backend_kwargs if hasattr(cls, f"plot_{k}")]

     def check_backend(self, backend):
         if backend is None:
@@ -88,7 +91,6 @@ def check_backend(self, backend):
             f"{backend} backend not available!
Available backends are: " f"{self.get_possible_backends()}" ) return backend - # def check_backend_kwargs(self, plotter, backend, **backend_kwargs): # plotter_kwargs = plotter.default_backend_kwargs @@ -102,7 +104,7 @@ def check_backend(self, backend): def do_plot(self, backend, **backend_kwargs): # backend = self.check_backend(backend) - func = getattr(self, f'plot_{backend}') + func = getattr(self, f"plot_{backend}") func(self.data_plot, **self.backend_kwargs) # @classmethod diff --git a/src/spikeinterface/widgets/crosscorrelograms.py b/src/spikeinterface/widgets/crosscorrelograms.py index eed76c3e04..4b83e61b69 100644 --- a/src/spikeinterface/widgets/crosscorrelograms.py +++ b/src/spikeinterface/widgets/crosscorrelograms.py @@ -124,9 +124,7 @@ def plot_sortingview(self, data_plot, **backend_kwargs): ) ) - self.view = vv.CrossCorrelograms( - cross_correlograms=cc_items, hide_unit_selector=dp.hide_unit_selector - ) + self.view = vv.CrossCorrelograms(cross_correlograms=cc_items, hide_unit_selector=dp.hide_unit_selector) # self.handle_display_and_url(v_cross_correlograms, **backend_kwargs) # return v_cross_correlograms diff --git a/src/spikeinterface/widgets/ipywidgets_utils.py b/src/spikeinterface/widgets/ipywidgets_utils.py index 4490cc3365..a7c571d1f0 100644 --- a/src/spikeinterface/widgets/ipywidgets_utils.py +++ b/src/spikeinterface/widgets/ipywidgets_utils.py @@ -2,14 +2,13 @@ import numpy as np - def check_ipywidget_backend(): import matplotlib + mpl_backend = matplotlib.get_backend() assert "ipympl" in mpl_backend, "To use the 'ipywidgets' backend, you have to set %matplotlib widget" - def make_timeseries_controller(t_start, t_stop, layer_keys, num_segments, time_range, mode, all_layers, width_cm): time_slider = widgets.FloatSlider( orientation="horizontal", diff --git a/src/spikeinterface/widgets/matplotlib_utils.py b/src/spikeinterface/widgets/matplotlib_utils.py index 6ccaaf5840..fb347552b1 100644 --- a/src/spikeinterface/widgets/matplotlib_utils.py +++ b/src/spikeinterface/widgets/matplotlib_utils.py @@ -65,11 +65,11 @@ def make_mpl_figure(figure=None, ax=None, axes=None, ncols=None, num_axes=None, figure.suptitle(figtitle) return figure, axes, ax - + # self.figure = figure # self.ax = ax # axes is always a 2D array of ax # self.axes = axes # if figtitle is not None: - # self.figure.suptitle(figtitle) \ No newline at end of file + # self.figure.suptitle(figtitle) diff --git a/src/spikeinterface/widgets/metrics.py b/src/spikeinterface/widgets/metrics.py index 207e3a8a6c..6551bb067e 100644 --- a/src/spikeinterface/widgets/metrics.py +++ b/src/spikeinterface/widgets/metrics.py @@ -91,7 +91,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): backend_kwargs["figsize"] = (2 * num_metrics, 2 * num_metrics) # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - backend_kwargs["num_axes"] = num_metrics ** 2 + backend_kwargs["num_axes"] = num_metrics**2 backend_kwargs["ncols"] = num_metrics all_unit_ids = metrics.index.values @@ -128,7 +128,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): self.figure.subplots_adjust(top=0.8, wspace=0.2, hspace=0.2) - def plot_ipywidgets(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt import ipywidgets.widgets as widgets @@ -169,7 +168,6 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): for w in self.controller.values(): w.observe(self._update_ipywidget) - self.widget = widgets.AppLayout( center=self.figure.canvas, left_sidebar=unit_widget, @@ -203,7 +201,7 @@ def _update_ipywidget(self, change): 
# here we do a trick: we just update colors # if hasattr(self.mpl_plotter, "patches"): if hasattr(self, "patches"): - # for p in self.mpl_plotter.patches: + # for p in self.mpl_plotter.patches: for p in self.patches: p.set_color(colors) p.set_sizes(sizes) @@ -242,7 +240,7 @@ def plot_sortingview(self, data_plot, **backend_kwargs): unit_ids = metrics.index.values else: unit_ids = dp.unit_ids - # unit_ids = self.make_serializable(unit_ids) + # unit_ids = self.make_serializable(unit_ids) unit_ids = make_serializable(unit_ids) metrics_sv = [] @@ -283,4 +281,4 @@ def plot_sortingview(self, data_plot, **backend_kwargs): # self.handle_display_and_url(view, **backend_kwargs) # return view - self.url = handle_display_and_url(self, self.view, **self.backend_kwargs) \ No newline at end of file + self.url = handle_display_and_url(self, self.view, **self.backend_kwargs) diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py index 48aba8de47..1ebbb71743 100644 --- a/src/spikeinterface/widgets/motion.py +++ b/src/spikeinterface/widgets/motion.py @@ -76,10 +76,8 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): from spikeinterface.sortingcomponents.motion_interpolation import correct_motion_on_peaks - dp = to_attr(data_plot) # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - assert backend_kwargs["axes"] is None assert backend_kwargs["ax"] is None @@ -191,4 +189,4 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax3.set_ylabel("Depth [um]") ax3.set_title("Motion vectors") axes.append(ax3) - self.axes = np.array(axes) \ No newline at end of file + self.axes = np.array(axes) diff --git a/src/spikeinterface/widgets/sorting_summary.py b/src/spikeinterface/widgets/sorting_summary.py index bdf692888f..5498df9a33 100644 --- a/src/spikeinterface/widgets/sorting_summary.py +++ b/src/spikeinterface/widgets/sorting_summary.py @@ -78,7 +78,6 @@ def __init__( unit_table_properties=unit_table_properties, curation=curation, label_choices=label_choices, - max_amplitudes_per_unit=max_amplitudes_per_unit, ) @@ -93,7 +92,6 @@ def plot_sortingview(self, data_plot, **backend_kwargs): unit_ids = dp.unit_ids sparsity = dp.sparsity - # unit_ids = self.make_serializable(dp.unit_ids) unit_ids = make_serializable(dp.unit_ids) @@ -117,21 +115,34 @@ def plot_sortingview(self, data_plot, **backend_kwargs): # ) v_spike_amplitudes = AmplitudesWidget( - we, unit_ids=unit_ids, max_spikes_per_unit=dp.max_amplitudes_per_unit, hide_unit_selector=True, - generate_url=False, display=False, backend="sortingview" + we, + unit_ids=unit_ids, + max_spikes_per_unit=dp.max_amplitudes_per_unit, + hide_unit_selector=True, + generate_url=False, + display=False, + backend="sortingview", ).view v_average_waveforms = UnitTemplatesWidget( - we, unit_ids=unit_ids, sparsity=sparsity, hide_unit_selector=True, - generate_url=False, display=False, backend="sortingview" + we, + unit_ids=unit_ids, + sparsity=sparsity, + hide_unit_selector=True, + generate_url=False, + display=False, + backend="sortingview", + ).view + v_cross_correlograms = CrossCorrelogramsWidget( + we, unit_ids=unit_ids, hide_unit_selector=True, generate_url=False, display=False, backend="sortingview" ).view - v_cross_correlograms = CrossCorrelogramsWidget(we, unit_ids=unit_ids, hide_unit_selector=True, - generate_url=False, display=False, backend="sortingview").view - - v_unit_locations = UnitLocationsWidget(we, unit_ids=unit_ids, hide_unit_selector=True, - generate_url=False, display=False, backend="sortingview").view - - w = 
TemplateSimilarityWidget(we, unit_ids=unit_ids, immediate_plot=False, - generate_url=False, display=False, backend="sortingview" ) + + v_unit_locations = UnitLocationsWidget( + we, unit_ids=unit_ids, hide_unit_selector=True, generate_url=False, display=False, backend="sortingview" + ).view + + w = TemplateSimilarityWidget( + we, unit_ids=unit_ids, immediate_plot=False, generate_url=False, display=False, backend="sortingview" + ) similarity = w.data_plot["similarity"] print(similarity.shape) @@ -140,9 +151,7 @@ def plot_sortingview(self, data_plot, **backend_kwargs): for i1, u1 in enumerate(unit_ids): for i2, u2 in enumerate(unit_ids): similarity_scores.append( - vv.UnitSimilarityScore( - unit_id1=u1, unit_id2=u2, similarity=similarity[i1, i2].astype("float32") - ) + vv.UnitSimilarityScore(unit_id1=u1, unit_id2=u2, similarity=similarity[i1, i2].astype("float32")) ) # unit ids @@ -179,7 +188,5 @@ def plot_sortingview(self, data_plot, **backend_kwargs): # self.handle_display_and_url(v_summary, **backend_kwargs) # return v_summary - - self.url = handle_display_and_url(self, self.view, **self.backend_kwargs) - + self.url = handle_display_and_url(self, self.view, **self.backend_kwargs) diff --git a/src/spikeinterface/widgets/sortingview_utils.py b/src/spikeinterface/widgets/sortingview_utils.py index 90dfcb77a3..f5339b4bbb 100644 --- a/src/spikeinterface/widgets/sortingview_utils.py +++ b/src/spikeinterface/widgets/sortingview_utils.py @@ -3,8 +3,6 @@ from ..core.core_tools import check_json - - sortingview_backend_kwargs_desc = { "generate_url": "If True, the figurl URL is generated and printed. Default True", "display": "If True and in jupyter notebook/lab, the widget is displayed in the cell. Default True.", @@ -14,7 +12,6 @@ sortingview_default_backend_kwargs = {"generate_url": True, "display": True, "figlabel": None, "height": None} - def make_serializable(*args): dict_to_serialize = {int(i): a for i, a in enumerate(args)} serializable_dict = check_json(dict_to_serialize) @@ -25,6 +22,7 @@ def make_serializable(*args): returns = returns[0] return returns + def is_notebook() -> bool: try: shell = get_ipython().__class__.__name__ @@ -37,6 +35,7 @@ def is_notebook() -> bool: except NameError: return False + def handle_display_and_url(widget, view, **backend_kwargs): url = None if is_notebook() and backend_kwargs["display"]: @@ -44,14 +43,12 @@ def handle_display_and_url(widget, view, **backend_kwargs): if backend_kwargs["generate_url"]: figlabel = backend_kwargs.get("figlabel") if figlabel is None: - # figlabel = widget.default_label + # figlabel = widget.default_label figlabel = "" url = view.url(label=figlabel) print(url) - - return url - + return url def generate_unit_table_view(sorting, unit_properties=None, similarity_scores=None): diff --git a/src/spikeinterface/widgets/spike_locations.py b/src/spikeinterface/widgets/spike_locations.py index d32c3c2f4c..06495409cf 100644 --- a/src/spikeinterface/widgets/spike_locations.py +++ b/src/spikeinterface/widgets/spike_locations.py @@ -111,7 +111,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): from matplotlib.lines import Line2D from probeinterface import ProbeGroup - from probeinterface.plotting import plot_probe + from probeinterface.plotting import plot_probe dp = to_attr(data_plot) # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) @@ -169,7 +169,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ] if dp.plot_legend: # if self.legend is not None: - if hasattr(self, 'legend') and self.legend is not None: + 
if hasattr(self, "legend") and self.legend is not None: self.legend.remove() self.legend = self.figure.legend( handles, labels, loc="upper center", bbox_to_anchor=(0.5, 1.0), ncol=5, fancybox=True, shadow=True @@ -245,13 +245,11 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): # self.updater(None) self._update_ipywidget(None) - if backend_kwargs["display"]: # self.check_backend() display(self.widget) def _update_ipywidget(self, change): - self.ax.clear() unit_ids = self.controller["unit_ids"].value @@ -272,7 +270,6 @@ def _update_ipywidget(self, change): fig.canvas.draw() fig.canvas.flush_events() - def plot_sortingview(self, data_plot, **backend_kwargs): import sortingview.views as vv from .sortingview_utils import generate_unit_table_view, make_serializable, handle_display_and_url @@ -282,7 +279,7 @@ def plot_sortingview(self, data_plot, **backend_kwargs): spike_locations = dp.spike_locations # ensure serializable for sortingview - # unit_ids, channel_ids = self.make_serializable(dp.unit_ids, dp.channel_ids) + # unit_ids, channel_ids = self.make_serializable(dp.unit_ids, dp.channel_ids) unit_ids, channel_ids = make_serializable(dp.unit_ids, dp.channel_ids) locations = {str(ch): dp.channel_locations[i_ch].astype("float32") for i_ch, ch in enumerate(channel_ids)} @@ -331,8 +328,6 @@ def plot_sortingview(self, data_plot, **backend_kwargs): self.url = handle_display_and_url(self, self.view, **self.backend_kwargs) - - def estimate_axis_lims(spike_locations, quantile=0.02): # set proper axis limits all_locs = np.concatenate(list(spike_locations.values())) diff --git a/src/spikeinterface/widgets/spikes_on_traces.py b/src/spikeinterface/widgets/spikes_on_traces.py index 9deb346387..0aeb923f38 100644 --- a/src/spikeinterface/widgets/spikes_on_traces.py +++ b/src/spikeinterface/widgets/spikes_on_traces.py @@ -173,8 +173,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): recording = we.recording sorting = we.sorting - - # first plot time series # tsplotter = TimeseriesPlotter() # data_plot["timeseries"]["add_legend"] = False @@ -189,7 +187,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): self.axes = ts_widget.axes self.figure = ts_widget.figure - ax = self.ax # we = dp.waveform_extractor @@ -204,7 +201,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): min_y = np.min(ts_widget.data_plot["channel_locations"][:, 1]) max_y = np.max(ts_widget.data_plot["channel_locations"][:, 1]) - # n = len(dp.timeseries["channel_ids"]) # order = dp.timeseries["order"] n = len(ts_widget.data_plot["channel_ids"]) @@ -263,10 +259,10 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): traces = ts_widget.data_plot["list_traces"][0] waveform_idxs = spike_frames_to_plot[:, None] + np.arange(-we.nbefore, we.nafter) - frame_range[0] - # waveform_idxs = np.clip(waveform_idxs, 0, len(dp.timeseries["times"]) - 1) + # waveform_idxs = np.clip(waveform_idxs, 0, len(dp.timeseries["times"]) - 1) waveform_idxs = np.clip(waveform_idxs, 0, len(ts_widget.data_plot["times"]) - 1) - # times = dp.timeseries["times"][waveform_idxs] + # times = dp.timeseries["times"][waveform_idxs] times = ts_widget.data_plot["times"][waveform_idxs] # discontinuity @@ -286,7 +282,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): label_set = True ax.legend(handles, labels) - def plot_ipywidgets(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt import ipywidgets.widgets as widgets @@ -300,7 +295,6 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): dp = to_attr(data_plot) we = 
dp.waveform_extractor - ratios = [0.2, 0.8] # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) @@ -323,15 +317,14 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.axes = ts_widget.axes self.figure = ts_widget.figure - # we = data_plot["waveform_extractor"] - + unit_widget, unit_controller = make_unit_controller( data_plot["unit_ids"], we.unit_ids, ratios[0] * width_cm, height_cm ) self.controller = dict() - # self.controller = ts_updater.controller + # self.controller = ts_updater.controller self.controller.update(ts_widget.controller) self.controller.update(unit_controller) @@ -344,7 +337,6 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): for w in self.controller.values(): w.observe(self._update_ipywidget) - self.widget = widgets.AppLayout(center=ts_widget.widget, left_sidebar=unit_widget, pane_widths=ratios + [0]) # a first update diff --git a/src/spikeinterface/widgets/template_similarity.py b/src/spikeinterface/widgets/template_similarity.py index 93b9a49f49..a6e0356db1 100644 --- a/src/spikeinterface/widgets/template_similarity.py +++ b/src/spikeinterface/widgets/template_similarity.py @@ -62,7 +62,7 @@ def __init__( ) BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) - + def plot_matplotlib(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt from .matplotlib_utils import make_mpl_figure @@ -91,7 +91,6 @@ def plot_sortingview(self, data_plot, **backend_kwargs): import sortingview.views as vv from .sortingview_utils import generate_unit_table_view, make_serializable, handle_display_and_url - # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) dp = to_attr(data_plot) @@ -112,4 +111,3 @@ def plot_sortingview(self, data_plot, **backend_kwargs): # self.handle_display_and_url(view, **backend_kwargs) # return view self.url = handle_display_and_url(self, self.view, **self.backend_kwargs) - diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index 4ddec4134b..610da470e8 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -13,7 +13,7 @@ from spikeinterface import extract_waveforms, load_waveforms, download_dataset, compute_sparsity -# from spikeinterface.widgets import HAVE_MPL, HAVE_SV +# from spikeinterface.widgets import HAVE_MPL, HAVE_SV import spikeinterface.extractors as se diff --git a/src/spikeinterface/widgets/timeseries.py b/src/spikeinterface/widgets/timeseries.py index 0e82c85b94..86e886babc 100644 --- a/src/spikeinterface/widgets/timeseries.py +++ b/src/spikeinterface/widgets/timeseries.py @@ -284,7 +284,12 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt import ipywidgets.widgets as widgets from IPython.display import display - from .ipywidgets_utils import check_ipywidget_backend, make_timeseries_controller, make_channel_controller, make_scale_controller + from .ipywidgets_utils import ( + check_ipywidget_backend, + make_timeseries_controller, + make_channel_controller, + make_scale_controller, + ) check_ipywidget_backend() @@ -499,7 +504,6 @@ def _update_ipywidget(self, change): fig.canvas.draw() fig.canvas.flush_events() - def plot_sortingview(self, data_plot, **backend_kwargs): import sortingview.views as vv from .sortingview_utils import generate_unit_table_view, make_serializable, handle_display_and_url @@ -545,11 +549,6 @@ def plot_sortingview(self, data_plot, **backend_kwargs): self.url = handle_display_and_url(self, 
self.view, **self.backend_kwargs) - - - - - def _get_trace_list(recordings, channel_ids, time_range, segment_index, order=None, return_scaled=False): # function also used in ipywidgets plotter k0 = list(recordings.keys())[0] diff --git a/src/spikeinterface/widgets/unit_depths.py b/src/spikeinterface/widgets/unit_depths.py index 9b710815e4..faf9198c0d 100644 --- a/src/spikeinterface/widgets/unit_depths.py +++ b/src/spikeinterface/widgets/unit_depths.py @@ -74,4 +74,3 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax.set_xlabel("amplitude") ax.set_ylabel("depth [um]") ax.set_xlim(0, max(dp.unit_amplitudes) * 1.2) - diff --git a/src/spikeinterface/widgets/unit_locations.py b/src/spikeinterface/widgets/unit_locations.py index 725a4c3023..9e35f7b32c 100644 --- a/src/spikeinterface/widgets/unit_locations.py +++ b/src/spikeinterface/widgets/unit_locations.py @@ -79,7 +79,7 @@ def __init__( plot_legend=plot_legend, hide_axis=hide_axis, ) - + BaseWidget.__init__(self, data_plot, backend=backend, **backend_kwargs) def plot_matplotlib(self, data_plot, **backend_kwargs): @@ -90,17 +90,12 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): from matplotlib.patches import Ellipse from matplotlib.lines import Line2D - - - dp = to_attr(data_plot) # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - # self.make_mpl_figure(**backend_kwargs) self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) - unit_locations = dp.unit_locations probegroup = ProbeGroup.from_dict(dp.probegroup_dict) @@ -161,8 +156,8 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ] if dp.plot_legend: - if hasattr(self, 'legend') and self.legend is not None: - # if self.legend is not None: + if hasattr(self, "legend") and self.legend is not None: + # if self.legend is not None: self.legend.remove() self.legend = self.figure.legend( handles, labels, loc="upper center", bbox_to_anchor=(0.5, 1.0), ncol=5, fancybox=True, shadow=True @@ -171,9 +166,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): if dp.hide_axis: self.ax.axis("off") - - - def plot_ipywidgets(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt import ipywidgets.widgets as widgets @@ -188,7 +180,6 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - width_cm = backend_kwargs["width_cm"] height_cm = backend_kwargs["height_cm"] @@ -227,7 +218,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): if backend_kwargs["display"]: display(self.widget) - + def _update_ipywidget(self, change): self.ax.clear() @@ -283,4 +274,3 @@ def plot_sortingview(self, data_plot, **backend_kwargs): # self.handle_display_and_url(view, **backend_kwargs) self.url = handle_display_and_url(self, self.view, **self.backend_kwargs) - diff --git a/src/spikeinterface/widgets/unit_summary.py b/src/spikeinterface/widgets/unit_summary.py index 68fa8b77d2..66f522e3ca 100644 --- a/src/spikeinterface/widgets/unit_summary.py +++ b/src/spikeinterface/widgets/unit_summary.py @@ -109,13 +109,12 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): from .matplotlib_utils import make_mpl_figure dp = to_attr(data_plot) - + unit_id = dp.unit_id we = dp.we unit_colors = dp.unit_colors sparsity = dp.sparsity - # force the figure without axes if "figsize" not in backend_kwargs: backend_kwargs["figsize"] = (18, 7) @@ -136,7 +135,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ncols += 1 # if dp.plot_data_amplitudes is not None : if 
we.is_extension("spike_amplitudes"): - nrows += 1 gs = fig.add_gridspec(nrows, ncols) @@ -145,9 +143,9 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax1 = fig.add_subplot(gs[:2, 0]) # UnitLocationsPlotter().do_plot(dp.plot_data_unit_locations, ax=ax1) w = UnitLocationsWidget( - we, unit_ids=[unit_id], unit_colors=unit_colors, plot_legend=False, - backend='matplotlib', ax=ax1) - + we, unit_ids=[unit_id], unit_colors=unit_colors, plot_legend=False, backend="matplotlib", ax=ax1 + ) + unit_locations = we.load_extension("unit_locations").get_data(outputs="by_unit") unit_location = unit_locations[unit_id] # x, y = dp.unit_location[0], dp.unit_location[1] @@ -161,22 +159,30 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax2 = fig.add_subplot(gs[:2, 1]) # UnitWaveformPlotter().do_plot(dp.plot_data_waveforms, ax=ax2) w = UnitWaveformsWidget( - we, - unit_ids=[unit_id], - unit_colors=unit_colors, - plot_templates=True, - same_axis=True, - plot_legend=False, - sparsity=sparsity, - backend='matplotlib', ax=ax2) - + we, + unit_ids=[unit_id], + unit_colors=unit_colors, + plot_templates=True, + same_axis=True, + plot_legend=False, + sparsity=sparsity, + backend="matplotlib", + ax=ax2, + ) + ax2.set_title(None) ax3 = fig.add_subplot(gs[:2, 2]) # UnitWaveformDensityMapPlotter().do_plot(dp.plot_data_waveform_density, ax=ax3) UnitWaveformDensityMapWidget( - we, unit_ids=[unit_id], unit_colors=unit_colors, use_max_channel=True, same_axis=False, - backend='matplotlib', ax=ax3) + we, + unit_ids=[unit_id], + unit_colors=unit_colors, + use_max_channel=True, + same_axis=False, + backend="matplotlib", + ax=ax3, + ) ax3.set_ylabel(None) # if dp.plot_data_acc is not None: @@ -187,10 +193,10 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): we, unit_ids=[unit_id], unit_colors=unit_colors, - backend='matplotlib', ax=ax4, + backend="matplotlib", + ax=ax4, ) - ax4.set_title(None) ax4.set_yticks([]) @@ -201,7 +207,13 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): axes = np.array([ax5, ax6]) # AmplitudesPlotter().do_plot(dp.plot_data_amplitudes, axes=axes) AmplitudesWidget( - we, unit_ids=[unit_id], unit_colors=unit_colors, plot_legend=False, plot_histograms=True, - backend='matplotlib', axes=axes) + we, + unit_ids=[unit_id], + unit_colors=unit_colors, + plot_legend=False, + plot_histograms=True, + backend="matplotlib", + axes=axes, + ) fig.suptitle(f"unit_id: {dp.unit_id}") diff --git a/src/spikeinterface/widgets/unit_templates.py b/src/spikeinterface/widgets/unit_templates.py index 84856d2df4..04b26e300f 100644 --- a/src/spikeinterface/widgets/unit_templates.py +++ b/src/spikeinterface/widgets/unit_templates.py @@ -1,5 +1,6 @@ from .unit_waveforms import UnitWaveformsWidget -from .base import to_attr +from .base import to_attr + class UnitTemplatesWidget(UnitWaveformsWidget): # possible_backends = {} @@ -56,6 +57,4 @@ def plot_sortingview(self, data_plot, **backend_kwargs): self.url = handle_display_and_url(self, self.view, **self.backend_kwargs) - - UnitTemplatesWidget.__doc__ = UnitWaveformsWidget.__doc__ diff --git a/src/spikeinterface/widgets/unit_waveforms.py b/src/spikeinterface/widgets/unit_waveforms.py index 49c75bf046..833f13881d 100644 --- a/src/spikeinterface/widgets/unit_waveforms.py +++ b/src/spikeinterface/widgets/unit_waveforms.py @@ -250,7 +250,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): if dp.same_axis and dp.plot_legend: # if self.legend is not None: - if hasattr(self, 'legend') and self.legend is not None: + if hasattr(self, "legend") and 
self.legend is not None: self.legend.remove() self.legend = self.figure.legend( loc="upper center", bbox_to_anchor=(0.5, 1.0), ncol=5, fancybox=True, shadow=True @@ -326,7 +326,6 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): for w in self.controller.values(): w.observe(self._update_ipywidget) - self.widget = widgets.AppLayout( center=self.fig_wf.canvas, left_sidebar=unit_widget, @@ -342,7 +341,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): if backend_kwargs["display"]: # self.check_backend() display(self.widget) - + def _update_ipywidget(self, change): self.fig_wf.clear() self.ax_probe.clear() @@ -373,10 +372,10 @@ def _update_ipywidget(self, change): # self.mpl_plotter.do_plot(data_plot, **backend_kwargs) self.plot_matplotlib(data_plot, **backend_kwargs) if same_axis: - # self.mpl_plotter.ax.axis("equal") + # self.mpl_plotter.ax.axis("equal") self.ax.axis("equal") if hide_axis: - # self.mpl_plotter.ax.axis("off") + # self.mpl_plotter.ax.axis("off") self.ax.axis("off") else: if hide_axis: diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index db73dbc5ec..a753c78d4a 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -1,4 +1,4 @@ -# from .base import define_widget_function_from_class +# from .base import define_widget_function_from_class from .base import backend_kwargs_desc # basics @@ -90,12 +90,12 @@ **backend_kwargs: kwargs {backend_kwargs} """ - # backend_str = f" {list(wcls.possible_backends.keys())}" + # backend_str = f" {list(wcls.possible_backends.keys())}" backend_str = f" {wcls.get_possible_backends()}" backend_kwargs_str = "" - # for backend, backend_plotter in wcls.possible_backends.items(): + # for backend, backend_plotter in wcls.possible_backends.items(): for backend in wcls.get_possible_backends(): - # backend_kwargs_desc = backend_plotter.backend_kwargs_desc + # backend_kwargs_desc = backend_plotter.backend_kwargs_desc kwargs_desc = backend_kwargs_desc[backend] if len(kwargs_desc) > 0: backend_kwargs_str += f"\n {backend}:\n\n" @@ -147,4 +147,3 @@ plot_unit_depths = UnitDepthsWidget plot_unit_summary = UnitSummaryWidget plot_sorting_summary = SortingSummaryWidget - From 7f189ff609367ca889f017a9f9a0f6eb6be0aeeb Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 20 Jul 2023 10:19:14 +0200 Subject: [PATCH 118/166] Allow order_channel_by_depth to accept dimentsions as list --- src/spikeinterface/core/recording_tools.py | 6 +++--- src/spikeinterface/preprocessing/depth_order.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index 865e5cc283..e5901d7ee0 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -312,9 +312,9 @@ def order_channels_by_depth(recording, channel_ids=None, dimensions=("x", "y")): The input recording channel_ids : list/array or None If given, a subset of channels to order locations for - dimensions : str or tuple + dimensions : str, tuple, or list If str, it needs to be 'x', 'y', 'z'. - If tuple, it sorts the locations in two dimensions using lexsort. + If tuple or list, it sorts the locations in two dimensions using lexsort. This approach is recommended since there is less ambiguity, by default ('x', 'y') Returns @@ -334,7 +334,7 @@ def order_channels_by_depth(recording, channel_ids=None, dimensions=("x", "y")): assert dim < ndim, "Invalid dimensions!" 
order_f = np.argsort(locations[:, dim], kind="stable") else: - assert isinstance(dimensions, tuple), "dimensions can be a str or a tuple" + assert isinstance(dimensions, (tuple, list)), "dimensions can be str, tuple, or list" locations_to_sort = () for dim in dimensions: dim = ["x", "y", "z"].index(dim) diff --git a/src/spikeinterface/preprocessing/depth_order.py b/src/spikeinterface/preprocessing/depth_order.py index 944b8d1f75..0b8d8a730b 100644 --- a/src/spikeinterface/preprocessing/depth_order.py +++ b/src/spikeinterface/preprocessing/depth_order.py @@ -14,9 +14,9 @@ class DepthOrderRecording(ChannelSliceRecording): The recording to re-order. channel_ids : list/array or None If given, a subset of channels to order locations for - dimensions : str or tuple + dimensions : str, tuple, list If str, it needs to be 'x', 'y', 'z'. - If tuple, it sorts the locations in two dimensions using lexsort. + If tuple or list, it sorts the locations in two dimensions using lexsort. This approach is recommended since there is less ambiguity, by default ('x', 'y') """ From 9768010ab2721c6814ca0aa00d395f00b9b4d84c Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 20 Jul 2023 10:30:02 +0200 Subject: [PATCH 119/166] wip --- src/spikeinterface/widgets/base.py | 8 ++++---- src/spikeinterface/widgets/sortingview_utils.py | 5 +++-- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/widgets/base.py b/src/spikeinterface/widgets/base.py index 7b0ba0454e..219787d87a 100644 --- a/src/spikeinterface/widgets/base.py +++ b/src/spikeinterface/widgets/base.py @@ -74,8 +74,8 @@ def __init__(self, data_plot=None, backend=None, immediate_plot=True, **backend_ self.backend_kwargs = backend_kwargs_ if immediate_plot: - print('immediate_plot', self.backend, self.backend_kwargs) - self.do_plot(self.backend, **self.backend_kwargs) + # print('immediate_plot', self.backend, self.backend_kwargs) + self.do_plot() @classmethod def get_possible_backends(cls): @@ -99,10 +99,10 @@ def check_backend(self, backend): # f"Possible backend keyword arguments for {backend} are: {list(plotter_kwargs.keys())}" # ) - def do_plot(self, backend, **backend_kwargs): + def do_plot(self): # backend = self.check_backend(backend) - func = getattr(self, f'plot_{backend}') + func = getattr(self, f'plot_{self.backend}') func(self.data_plot, **self.backend_kwargs) # @classmethod diff --git a/src/spikeinterface/widgets/sortingview_utils.py b/src/spikeinterface/widgets/sortingview_utils.py index 90dfcb77a3..c513c1f2b6 100644 --- a/src/spikeinterface/widgets/sortingview_utils.py +++ b/src/spikeinterface/widgets/sortingview_utils.py @@ -39,8 +39,9 @@ def is_notebook() -> bool: def handle_display_and_url(widget, view, **backend_kwargs): url = None - if is_notebook() and backend_kwargs["display"]: - display(view.jupyter(height=backend_kwargs["height"])) + # TODO: put this back when figurl-jupyter is working again + # if is_notebook() and backend_kwargs["display"]: + # display(view.jupyter(height=backend_kwargs["height"])) if backend_kwargs["generate_url"]: figlabel = backend_kwargs.get("figlabel") if figlabel is None: From 6a7d337b91d0a70de91ae5efc814bbee8f1a80de Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 20 Jul 2023 10:39:43 +0200 Subject: [PATCH 120/166] remove old backend folder (matplotlib, ipywidgets, sortingview) not needed anymore --- .../widgets/ipywidgets/__init__.py | 9 - .../widgets/ipywidgets/amplitudes.py | 99 -------- .../widgets/ipywidgets/base_ipywidgets.py | 20 -- .../widgets/ipywidgets/metrics.py | 
108 -------- .../widgets/ipywidgets/quality_metrics.py | 9 - .../widgets/ipywidgets/spike_locations.py | 97 -------- .../widgets/ipywidgets/spikes_on_traces.py | 145 ----------- .../widgets/ipywidgets/template_metrics.py | 9 - .../widgets/ipywidgets/timeseries.py | 232 ------------------ .../widgets/ipywidgets/unit_locations.py | 91 ------- .../widgets/ipywidgets/unit_templates.py | 11 - .../widgets/ipywidgets/unit_waveforms.py | 169 ------------- .../widgets/ipywidgets/utils.py | 97 -------- .../widgets/matplotlib/__init__.py | 17 -- .../all_amplitudes_distributions.py | 41 ---- .../widgets/matplotlib/amplitudes.py | 69 ------ .../widgets/matplotlib/autocorrelograms.py | 30 --- .../widgets/matplotlib/base_mpl.py | 102 -------- .../widgets/matplotlib/crosscorrelograms.py | 39 --- .../widgets/matplotlib/metrics.py | 50 ---- .../widgets/matplotlib/motion.py | 129 ---------- .../widgets/matplotlib/quality_metrics.py | 9 - .../widgets/matplotlib/spike_locations.py | 96 -------- .../widgets/matplotlib/spikes_on_traces.py | 104 -------- .../widgets/matplotlib/template_metrics.py | 9 - .../widgets/matplotlib/template_similarity.py | 30 --- .../widgets/matplotlib/timeseries.py | 70 ------ .../widgets/matplotlib/unit_depths.py | 22 -- .../widgets/matplotlib/unit_locations.py | 95 ------- .../widgets/matplotlib/unit_summary.py | 73 ------ .../widgets/matplotlib/unit_templates.py | 9 - .../widgets/matplotlib/unit_waveforms.py | 95 ------- .../matplotlib/unit_waveforms_density_map.py | 77 ------ .../widgets/sortingview/__init__.py | 11 - .../widgets/sortingview/amplitudes.py | 36 --- .../widgets/sortingview/autocorrelograms.py | 34 --- .../widgets/sortingview/base_sortingview.py | 103 -------- .../widgets/sortingview/crosscorrelograms.py | 37 --- .../widgets/sortingview/metrics.py | 61 ----- .../widgets/sortingview/quality_metrics.py | 11 - .../widgets/sortingview/sorting_summary.py | 86 ------- .../widgets/sortingview/spike_locations.py | 64 ----- .../widgets/sortingview/template_metrics.py | 11 - .../sortingview/template_similarity.py | 32 --- .../widgets/sortingview/timeseries.py | 54 ---- .../widgets/sortingview/unit_locations.py | 44 ---- .../widgets/sortingview/unit_templates.py | 54 ---- 47 files changed, 2900 deletions(-) delete mode 100644 src/spikeinterface/widgets/ipywidgets/__init__.py delete mode 100644 src/spikeinterface/widgets/ipywidgets/amplitudes.py delete mode 100644 src/spikeinterface/widgets/ipywidgets/base_ipywidgets.py delete mode 100644 src/spikeinterface/widgets/ipywidgets/metrics.py delete mode 100644 src/spikeinterface/widgets/ipywidgets/quality_metrics.py delete mode 100644 src/spikeinterface/widgets/ipywidgets/spike_locations.py delete mode 100644 src/spikeinterface/widgets/ipywidgets/spikes_on_traces.py delete mode 100644 src/spikeinterface/widgets/ipywidgets/template_metrics.py delete mode 100644 src/spikeinterface/widgets/ipywidgets/timeseries.py delete mode 100644 src/spikeinterface/widgets/ipywidgets/unit_locations.py delete mode 100644 src/spikeinterface/widgets/ipywidgets/unit_templates.py delete mode 100644 src/spikeinterface/widgets/ipywidgets/unit_waveforms.py delete mode 100644 src/spikeinterface/widgets/ipywidgets/utils.py delete mode 100644 src/spikeinterface/widgets/matplotlib/__init__.py delete mode 100644 src/spikeinterface/widgets/matplotlib/all_amplitudes_distributions.py delete mode 100644 src/spikeinterface/widgets/matplotlib/amplitudes.py delete mode 100644 src/spikeinterface/widgets/matplotlib/autocorrelograms.py delete mode 100644 
src/spikeinterface/widgets/matplotlib/base_mpl.py delete mode 100644 src/spikeinterface/widgets/matplotlib/crosscorrelograms.py delete mode 100644 src/spikeinterface/widgets/matplotlib/metrics.py delete mode 100644 src/spikeinterface/widgets/matplotlib/motion.py delete mode 100644 src/spikeinterface/widgets/matplotlib/quality_metrics.py delete mode 100644 src/spikeinterface/widgets/matplotlib/spike_locations.py delete mode 100644 src/spikeinterface/widgets/matplotlib/spikes_on_traces.py delete mode 100644 src/spikeinterface/widgets/matplotlib/template_metrics.py delete mode 100644 src/spikeinterface/widgets/matplotlib/template_similarity.py delete mode 100644 src/spikeinterface/widgets/matplotlib/timeseries.py delete mode 100644 src/spikeinterface/widgets/matplotlib/unit_depths.py delete mode 100644 src/spikeinterface/widgets/matplotlib/unit_locations.py delete mode 100644 src/spikeinterface/widgets/matplotlib/unit_summary.py delete mode 100644 src/spikeinterface/widgets/matplotlib/unit_templates.py delete mode 100644 src/spikeinterface/widgets/matplotlib/unit_waveforms.py delete mode 100644 src/spikeinterface/widgets/matplotlib/unit_waveforms_density_map.py delete mode 100644 src/spikeinterface/widgets/sortingview/__init__.py delete mode 100644 src/spikeinterface/widgets/sortingview/amplitudes.py delete mode 100644 src/spikeinterface/widgets/sortingview/autocorrelograms.py delete mode 100644 src/spikeinterface/widgets/sortingview/base_sortingview.py delete mode 100644 src/spikeinterface/widgets/sortingview/crosscorrelograms.py delete mode 100644 src/spikeinterface/widgets/sortingview/metrics.py delete mode 100644 src/spikeinterface/widgets/sortingview/quality_metrics.py delete mode 100644 src/spikeinterface/widgets/sortingview/sorting_summary.py delete mode 100644 src/spikeinterface/widgets/sortingview/spike_locations.py delete mode 100644 src/spikeinterface/widgets/sortingview/template_metrics.py delete mode 100644 src/spikeinterface/widgets/sortingview/template_similarity.py delete mode 100644 src/spikeinterface/widgets/sortingview/timeseries.py delete mode 100644 src/spikeinterface/widgets/sortingview/unit_locations.py delete mode 100644 src/spikeinterface/widgets/sortingview/unit_templates.py diff --git a/src/spikeinterface/widgets/ipywidgets/__init__.py b/src/spikeinterface/widgets/ipywidgets/__init__.py deleted file mode 100644 index 63d1b3a486..0000000000 --- a/src/spikeinterface/widgets/ipywidgets/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -from .amplitudes import AmplitudesPlotter -from .quality_metrics import QualityMetricsPlotter -from .spike_locations import SpikeLocationsPlotter -from .spikes_on_traces import SpikesOnTracesPlotter -from .template_metrics import TemplateMetricsPlotter -from .timeseries import TimeseriesPlotter -from .unit_locations import UnitLocationsPlotter -from .unit_templates import UnitTemplatesPlotter -from .unit_waveforms import UnitWaveformPlotter diff --git a/src/spikeinterface/widgets/ipywidgets/amplitudes.py b/src/spikeinterface/widgets/ipywidgets/amplitudes.py deleted file mode 100644 index dc55b927e0..0000000000 --- a/src/spikeinterface/widgets/ipywidgets/amplitudes.py +++ /dev/null @@ -1,99 +0,0 @@ -import numpy as np - -import matplotlib.pyplot as plt -import ipywidgets.widgets as widgets - - -from ..base import to_attr - -from .base_ipywidgets import IpywidgetsPlotter -from .utils import make_unit_controller - -from ..amplitudes import AmplitudesWidget -from ..matplotlib.amplitudes import AmplitudesPlotter as MplAmplitudesPlotter - -from 
IPython.display import display - - -class AmplitudesPlotter(IpywidgetsPlotter): - def do_plot(self, data_plot, **backend_kwargs): - cm = 1 / 2.54 - we = data_plot["waveform_extractor"] - - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - width_cm = backend_kwargs["width_cm"] - height_cm = backend_kwargs["height_cm"] - - ratios = [0.15, 0.85] - - with plt.ioff(): - output = widgets.Output() - with output: - fig = plt.figure(figsize=((ratios[1] * width_cm) * cm, height_cm * cm)) - plt.show() - - data_plot["unit_ids"] = data_plot["unit_ids"][:1] - unit_widget, unit_controller = make_unit_controller( - data_plot["unit_ids"], we.unit_ids, ratios[0] * width_cm, height_cm - ) - - plot_histograms = widgets.Checkbox( - value=data_plot["plot_histograms"], - description="plot histograms", - disabled=False, - ) - - footer = plot_histograms - - self.controller = {"plot_histograms": plot_histograms} - self.controller.update(unit_controller) - - mpl_plotter = MplAmplitudesPlotter() - - self.updater = PlotUpdater(data_plot, mpl_plotter, fig, self.controller) - for w in self.controller.values(): - w.observe(self.updater) - - self.widget = widgets.AppLayout( - center=fig.canvas, left_sidebar=unit_widget, pane_widths=ratios + [0], footer=footer - ) - - # a first update - self.updater(None) - - if backend_kwargs["display"]: - self.check_backend() - display(self.widget) - - -AmplitudesPlotter.register(AmplitudesWidget) - - -class PlotUpdater: - def __init__(self, data_plot, mpl_plotter, fig, controller): - self.data_plot = data_plot - self.mpl_plotter = mpl_plotter - self.fig = fig - self.controller = controller - - self.we = data_plot["waveform_extractor"] - self.next_data_plot = data_plot.copy() - - def __call__(self, change): - self.fig.clear() - - unit_ids = self.controller["unit_ids"].value - plot_histograms = self.controller["plot_histograms"].value - - # matplotlib next_data_plot dict update at each call - data_plot = self.next_data_plot - data_plot["unit_ids"] = unit_ids - data_plot["plot_histograms"] = plot_histograms - - backend_kwargs = {} - backend_kwargs["figure"] = self.fig - - self.mpl_plotter.do_plot(data_plot, **backend_kwargs) - - self.fig.canvas.draw() - self.fig.canvas.flush_events() diff --git a/src/spikeinterface/widgets/ipywidgets/base_ipywidgets.py b/src/spikeinterface/widgets/ipywidgets/base_ipywidgets.py deleted file mode 100644 index e0eff7f330..0000000000 --- a/src/spikeinterface/widgets/ipywidgets/base_ipywidgets.py +++ /dev/null @@ -1,20 +0,0 @@ -from spikeinterface.widgets.base import BackendPlotter - -import matplotlib as mpl -import matplotlib.pyplot as plt -from matplotlib import gridspec -import numpy as np - - -class IpywidgetsPlotter(BackendPlotter): - backend = "ipywidgets" - backend_kwargs_desc = { - "width_cm": "Width of the figure in cm (default 10)", - "height_cm": "Height of the figure in cm (default 6)", - "display": "If True, widgets are immediately displayed", - } - default_backend_kwargs = {"width_cm": 25, "height_cm": 10, "display": True} - - def check_backend(self): - mpl_backend = mpl.get_backend() - assert "ipympl" in mpl_backend, "To use the 'ipywidgets' backend, you have to set %matplotlib widget" diff --git a/src/spikeinterface/widgets/ipywidgets/metrics.py b/src/spikeinterface/widgets/ipywidgets/metrics.py deleted file mode 100644 index ba6859b2a1..0000000000 --- a/src/spikeinterface/widgets/ipywidgets/metrics.py +++ /dev/null @@ -1,108 +0,0 @@ -import numpy as np - -import matplotlib.pyplot as plt -import ipywidgets.widgets as widgets - 
-from matplotlib.lines import Line2D - -from ..base import to_attr - -from .base_ipywidgets import IpywidgetsPlotter -from .utils import make_unit_controller - -from ..matplotlib.metrics import MetricsPlotter as MplMetricsPlotter - -from IPython.display import display - - -class MetricsPlotter(IpywidgetsPlotter): - def do_plot(self, data_plot, **backend_kwargs): - cm = 1 / 2.54 - - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - width_cm = backend_kwargs["width_cm"] - height_cm = backend_kwargs["height_cm"] - - ratios = [0.15, 0.85] - - with plt.ioff(): - output = widgets.Output() - with output: - fig = plt.figure(figsize=((ratios[1] * width_cm) * cm, height_cm * cm)) - plt.show() - if data_plot["unit_ids"] is None: - data_plot["unit_ids"] = [] - - unit_widget, unit_controller = make_unit_controller( - data_plot["unit_ids"], list(data_plot["unit_colors"].keys()), ratios[0] * width_cm, height_cm - ) - - self.controller = unit_controller - - mpl_plotter = MplMetricsPlotter() - - self.updater = PlotUpdater(data_plot, mpl_plotter, fig, self.controller) - for w in self.controller.values(): - w.observe(self.updater) - - self.widget = widgets.AppLayout( - center=fig.canvas, - left_sidebar=unit_widget, - pane_widths=ratios + [0], - ) - - # a first update - self.updater(None) - - if backend_kwargs["display"]: - self.check_backend() - display(self.widget) - - -class PlotUpdater: - def __init__(self, data_plot, mpl_plotter, fig, controller): - self.data_plot = data_plot - self.mpl_plotter = mpl_plotter - self.fig = fig - self.controller = controller - self.unit_colors = data_plot["unit_colors"] - - self.next_data_plot = data_plot.copy() - - def __call__(self, change): - unit_ids = self.controller["unit_ids"].value - - # matplotlib next_data_plot dict update at each call - all_units = list(self.unit_colors.keys()) - colors = [] - sizes = [] - for unit in all_units: - color = "gray" if unit not in unit_ids else self.unit_colors[unit] - size = 1 if unit not in unit_ids else 5 - colors.append(color) - sizes.append(size) - - # here we do a trick: we just update colors - if hasattr(self.mpl_plotter, "patches"): - for p in self.mpl_plotter.patches: - p.set_color(colors) - p.set_sizes(sizes) - else: - backend_kwargs = {} - backend_kwargs["figure"] = self.fig - self.mpl_plotter.do_plot(self.data_plot, **backend_kwargs) - - if len(unit_ids) > 0: - for l in self.fig.legends: - l.remove() - handles = [ - Line2D([0], [0], ls="", marker="o", markersize=5, markeredgewidth=2, color=self.unit_colors[unit]) - for unit in unit_ids - ] - labels = unit_ids - self.fig.legend( - handles, labels, loc="upper center", bbox_to_anchor=(0.5, 1.0), ncol=5, fancybox=True, shadow=True - ) - - self.fig.canvas.draw() - self.fig.canvas.flush_events() diff --git a/src/spikeinterface/widgets/ipywidgets/quality_metrics.py b/src/spikeinterface/widgets/ipywidgets/quality_metrics.py deleted file mode 100644 index 3fc368770b..0000000000 --- a/src/spikeinterface/widgets/ipywidgets/quality_metrics.py +++ /dev/null @@ -1,9 +0,0 @@ -from ..quality_metrics import QualityMetricsWidget -from .metrics import MetricsPlotter - - -class QualityMetricsPlotter(MetricsPlotter): - pass - - -QualityMetricsPlotter.register(QualityMetricsWidget) diff --git a/src/spikeinterface/widgets/ipywidgets/spike_locations.py b/src/spikeinterface/widgets/ipywidgets/spike_locations.py deleted file mode 100644 index 633eb0ac39..0000000000 --- a/src/spikeinterface/widgets/ipywidgets/spike_locations.py +++ /dev/null @@ -1,97 +0,0 @@ -import numpy as np - 
-import matplotlib.pyplot as plt -import ipywidgets.widgets as widgets - - -from ..base import to_attr - -from .base_ipywidgets import IpywidgetsPlotter -from .utils import make_unit_controller - -from ..spike_locations import SpikeLocationsWidget -from ..matplotlib.spike_locations import ( - SpikeLocationsPlotter as MplSpikeLocationsPlotter, -) - -from IPython.display import display - - -class SpikeLocationsPlotter(IpywidgetsPlotter): - def do_plot(self, data_plot, **backend_kwargs): - cm = 1 / 2.54 - - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - width_cm = backend_kwargs["width_cm"] - height_cm = backend_kwargs["height_cm"] - - ratios = [0.15, 0.85] - - with plt.ioff(): - output = widgets.Output() - with output: - fig, ax = plt.subplots(figsize=((ratios[1] * width_cm) * cm, height_cm * cm)) - plt.show() - - data_plot["unit_ids"] = data_plot["unit_ids"][:1] - - unit_widget, unit_controller = make_unit_controller( - data_plot["unit_ids"], - list(data_plot["unit_colors"].keys()), - ratios[0] * width_cm, - height_cm, - ) - - self.controller = unit_controller - - mpl_plotter = MplSpikeLocationsPlotter() - - self.updater = PlotUpdater(data_plot, mpl_plotter, ax, self.controller) - for w in self.controller.values(): - w.observe(self.updater) - - self.widget = widgets.AppLayout( - center=fig.canvas, - left_sidebar=unit_widget, - pane_widths=ratios + [0], - ) - - # a first update - self.updater(None) - - if backend_kwargs["display"]: - self.check_backend() - display(self.widget) - - -SpikeLocationsPlotter.register(SpikeLocationsWidget) - - -class PlotUpdater: - def __init__(self, data_plot, mpl_plotter, ax, controller): - self.data_plot = data_plot - self.mpl_plotter = mpl_plotter - self.ax = ax - self.controller = controller - - self.next_data_plot = data_plot.copy() - - def __call__(self, change): - self.ax.clear() - - unit_ids = self.controller["unit_ids"].value - - # matplotlib next_data_plot dict update at each call - data_plot = self.next_data_plot - data_plot["unit_ids"] = unit_ids - data_plot["plot_all_units"] = True - data_plot["plot_legend"] = True - data_plot["hide_axis"] = True - - backend_kwargs = {} - backend_kwargs["ax"] = self.ax - - self.mpl_plotter.do_plot(data_plot, **backend_kwargs) - fig = self.ax.get_figure() - fig.canvas.draw() - fig.canvas.flush_events() diff --git a/src/spikeinterface/widgets/ipywidgets/spikes_on_traces.py b/src/spikeinterface/widgets/ipywidgets/spikes_on_traces.py deleted file mode 100644 index e5a3ebcc71..0000000000 --- a/src/spikeinterface/widgets/ipywidgets/spikes_on_traces.py +++ /dev/null @@ -1,145 +0,0 @@ -import numpy as np - -import matplotlib.pyplot as plt -import ipywidgets.widgets as widgets - - -from .base_ipywidgets import IpywidgetsPlotter -from .timeseries import TimeseriesPlotter -from .utils import make_unit_controller - -from ..spikes_on_traces import SpikesOnTracesWidget -from ..matplotlib.spikes_on_traces import SpikesOnTracesPlotter as MplSpikesOnTracesPlotter - -from IPython.display import display - - -class SpikesOnTracesPlotter(IpywidgetsPlotter): - def do_plot(self, data_plot, **backend_kwargs): - ratios = [0.2, 0.8] - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - backend_kwargs_ts = backend_kwargs.copy() - backend_kwargs_ts["width_cm"] = ratios[1] * backend_kwargs_ts["width_cm"] - backend_kwargs_ts["display"] = False - height_cm = backend_kwargs["height_cm"] - width_cm = backend_kwargs["width_cm"] - - # plot timeseries - tsplotter = TimeseriesPlotter() - data_plot["timeseries"]["add_legend"] 
= False - tsplotter.do_plot(data_plot["timeseries"], **backend_kwargs_ts) - - ts_w = tsplotter.widget - ts_updater = tsplotter.updater - - we = data_plot["waveform_extractor"] - unit_widget, unit_controller = make_unit_controller( - data_plot["unit_ids"], we.unit_ids, ratios[0] * width_cm, height_cm - ) - - self.controller = ts_updater.controller - self.controller.update(unit_controller) - - mpl_plotter = MplSpikesOnTracesPlotter() - - self.updater = PlotUpdater(data_plot, mpl_plotter, ts_updater, self.controller) - for w in self.controller.values(): - w.observe(self.updater) - - self.widget = widgets.AppLayout(center=ts_w, left_sidebar=unit_widget, pane_widths=ratios + [0]) - - # a first update - self.updater(None) - - if backend_kwargs["display"]: - self.check_backend() - display(self.widget) - - -SpikesOnTracesPlotter.register(SpikesOnTracesWidget) - - -class PlotUpdater: - def __init__(self, data_plot, mpl_plotter, ts_updater, controller): - self.data_plot = data_plot - self.mpl_plotter = mpl_plotter - - self.ts_updater = ts_updater - self.ax = ts_updater.ax - self.fig = self.ax.figure - self.controller = controller - - def __call__(self, change): - self.ax.clear() - - unit_ids = self.controller["unit_ids"].value - - # update ts - # self.ts_updater.__call__(change) - - # update data plot - data_plot = self.data_plot.copy() - data_plot["timeseries"] = self.ts_updater.next_data_plot - data_plot["unit_ids"] = unit_ids - - backend_kwargs = {} - backend_kwargs["ax"] = self.ax - - self.mpl_plotter.do_plot(data_plot, **backend_kwargs) - - self.fig.canvas.draw() - self.fig.canvas.flush_events() - - # t = self.time_slider.value - # d = self.win_sizer.value - - # selected_layer = self.layer_selector.value - # segment_index = self.seg_selector.value - # mode = self.mode_selector.value - - # t_stop = self.t_stops[segment_index] - # if self.actual_segment_index != segment_index: - # # change time_slider limits - # self.time_slider.max = t_stop - # self.actual_segment_index = segment_index - - # # protect limits - # if t >= t_stop - d: - # t = t_stop - d - - # time_range = np.array([t, t+d]) - - # if mode =='line': - # # plot all layer - # layer_keys = self.data_plot['layer_keys'] - # recordings = self.recordings - # clims = None - # elif mode =='map': - # layer_keys = [selected_layer] - # recordings = {selected_layer: self.recordings[selected_layer]} - # clims = {selected_layer: self.data_plot["clims"][selected_layer]} - - # channel_ids = self.data_plot['channel_ids'] - # order = self.data_plot['order'] - # times, list_traces, frame_range, order = _get_trace_list(recordings, channel_ids, time_range, order, - # segment_index) - - # # matplotlib next_data_plot dict update at each call - # data_plot = self.next_data_plot - # data_plot['mode'] = mode - # data_plot['frame_range'] = frame_range - # data_plot['time_range'] = time_range - # data_plot['with_colorbar'] = False - # data_plot['recordings'] = recordings - # data_plot['layer_keys'] = layer_keys - # data_plot['list_traces'] = list_traces - # data_plot['times'] = times - # data_plot['clims'] = clims - - # backend_kwargs = {} - # backend_kwargs['ax'] = self.ax - # self.mpl_plotter.do_plot(data_plot, **backend_kwargs) - - # fig = self.ax.figure - # fig.canvas.draw() - # fig.canvas.flush_events() diff --git a/src/spikeinterface/widgets/ipywidgets/template_metrics.py b/src/spikeinterface/widgets/ipywidgets/template_metrics.py deleted file mode 100644 index 0aea8ae428..0000000000 --- a/src/spikeinterface/widgets/ipywidgets/template_metrics.py +++ 
/dev/null @@ -1,9 +0,0 @@ -from ..template_metrics import TemplateMetricsWidget -from .metrics import MetricsPlotter - - -class TemplateMetricsPlotter(MetricsPlotter): - pass - - -TemplateMetricsPlotter.register(TemplateMetricsWidget) diff --git a/src/spikeinterface/widgets/ipywidgets/timeseries.py b/src/spikeinterface/widgets/ipywidgets/timeseries.py deleted file mode 100644 index 2448166f16..0000000000 --- a/src/spikeinterface/widgets/ipywidgets/timeseries.py +++ /dev/null @@ -1,232 +0,0 @@ -import numpy as np - -import matplotlib.pyplot as plt -import ipywidgets.widgets as widgets - -from ...core import order_channels_by_depth - -from .base_ipywidgets import IpywidgetsPlotter -from .utils import make_timeseries_controller, make_channel_controller, make_scale_controller - -from ..timeseries import TimeseriesWidget, _get_trace_list -from ..matplotlib.timeseries import TimeseriesPlotter as MplTimeseriesPlotter - -from IPython.display import display - - -class TimeseriesPlotter(IpywidgetsPlotter): - def do_plot(self, data_plot, **backend_kwargs): - recordings = data_plot["recordings"] - - # first layer - rec0 = recordings[data_plot["layer_keys"][0]] - - cm = 1 / 2.54 - - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - width_cm = backend_kwargs["width_cm"] - height_cm = backend_kwargs["height_cm"] - ratios = [0.1, 0.8, 0.2] - - with plt.ioff(): - output = widgets.Output() - with output: - fig, ax = plt.subplots(figsize=(0.9 * ratios[1] * width_cm * cm, height_cm * cm)) - plt.show() - - t_start = 0.0 - t_stop = rec0.get_num_samples(segment_index=0) / rec0.get_sampling_frequency() - - ts_widget, ts_controller = make_timeseries_controller( - t_start, - t_stop, - data_plot["layer_keys"], - rec0.get_num_segments(), - data_plot["time_range"], - data_plot["mode"], - False, - width_cm, - ) - - ch_widget, ch_controller = make_channel_controller(rec0, width_cm=ratios[2] * width_cm, height_cm=height_cm) - - scale_widget, scale_controller = make_scale_controller(width_cm=ratios[0] * width_cm, height_cm=height_cm) - - self.controller = ts_controller - self.controller.update(ch_controller) - self.controller.update(scale_controller) - - mpl_plotter = MplTimeseriesPlotter() - - self.updater = PlotUpdater(data_plot, mpl_plotter, ax, self.controller) - for w in self.controller.values(): - if isinstance(w, widgets.Button): - w.on_click(self.updater) - else: - w.observe(self.updater) - - self.widget = widgets.AppLayout( - center=fig.canvas, - footer=ts_widget, - left_sidebar=scale_widget, - right_sidebar=ch_widget, - pane_heights=[0, 6, 1], - pane_widths=ratios, - ) - - # a first update - self.updater(None) - - if backend_kwargs["display"]: - self.check_backend() - display(self.widget) - - -TimeseriesPlotter.register(TimeseriesWidget) - - -class PlotUpdater: - def __init__(self, data_plot, mpl_plotter, ax, controller): - self.data_plot = data_plot - self.mpl_plotter = mpl_plotter - - self.ax = ax - self.controller = controller - - self.recordings = data_plot["recordings"] - self.return_scaled = data_plot["return_scaled"] - self.next_data_plot = data_plot.copy() - self.list_traces = None - - self.actual_segment_index = self.controller["segment_index"].value - - self.rec0 = self.recordings[self.data_plot["layer_keys"][0]] - self.t_stops = [ - self.rec0.get_num_samples(segment_index=seg_index) / self.rec0.get_sampling_frequency() - for seg_index in range(self.rec0.get_num_segments()) - ] - - def __call__(self, change): - self.ax.clear() - - # if changing the layer_key, no need to retrieve and 
process traces - retrieve_traces = True - scale_up = False - scale_down = False - if change is not None: - for cname, c in self.controller.items(): - if isinstance(change, dict): - if change["owner"] is c and cname == "layer_key": - retrieve_traces = False - elif isinstance(change, widgets.Button): - if change is c and cname == "plus": - scale_up = True - if change is c and cname == "minus": - scale_down = True - - t_start = self.controller["t_start"].value - window = self.controller["window"].value - layer_key = self.controller["layer_key"].value - segment_index = self.controller["segment_index"].value - mode = self.controller["mode"].value - chan_start, chan_stop = self.controller["channel_inds"].value - - if mode == "line": - self.controller["all_layers"].layout.visibility = "visible" - all_layers = self.controller["all_layers"].value - elif mode == "map": - self.controller["all_layers"].layout.visibility = "hidden" - all_layers = False - - if all_layers: - self.controller["layer_key"].layout.visibility = "hidden" - else: - self.controller["layer_key"].layout.visibility = "visible" - - if chan_start == chan_stop: - chan_stop += 1 - channel_indices = np.arange(chan_start, chan_stop) - - t_stop = self.t_stops[segment_index] - if self.actual_segment_index != segment_index: - # change time_slider limits - self.controller["t_start"].max = t_stop - self.actual_segment_index = segment_index - - # protect limits - if t_start >= t_stop - window: - t_start = t_stop - window - - time_range = np.array([t_start, t_start + window]) - data_plot = self.next_data_plot - - if retrieve_traces: - all_channel_ids = self.recordings[list(self.recordings.keys())[0]].channel_ids - if self.data_plot["order"] is not None: - all_channel_ids = all_channel_ids[self.data_plot["order"]] - channel_ids = all_channel_ids[channel_indices] - if self.data_plot["order_channel_by_depth"]: - order, _ = order_channels_by_depth(self.rec0, channel_ids) - else: - order = None - times, list_traces, frame_range, channel_ids = _get_trace_list( - self.recordings, channel_ids, time_range, segment_index, order, self.return_scaled - ) - self.list_traces = list_traces - else: - times = data_plot["times"] - list_traces = data_plot["list_traces"] - frame_range = data_plot["frame_range"] - channel_ids = data_plot["channel_ids"] - - if all_layers: - layer_keys = self.data_plot["layer_keys"] - recordings = self.recordings - list_traces_plot = self.list_traces - else: - layer_keys = [layer_key] - recordings = {layer_key: self.recordings[layer_key]} - list_traces_plot = [self.list_traces[list(self.recordings.keys()).index(layer_key)]] - - if scale_up: - if mode == "line": - data_plot["vspacing"] *= 0.8 - elif mode == "map": - data_plot["clims"] = { - layer: (1.2 * val[0], 1.2 * val[1]) for layer, val in self.data_plot["clims"].items() - } - if scale_down: - if mode == "line": - data_plot["vspacing"] *= 1.2 - elif mode == "map": - data_plot["clims"] = { - layer: (0.8 * val[0], 0.8 * val[1]) for layer, val in self.data_plot["clims"].items() - } - - self.next_data_plot["vspacing"] = data_plot["vspacing"] - self.next_data_plot["clims"] = data_plot["clims"] - - if mode == "line": - clims = None - elif mode == "map": - clims = {layer_key: self.data_plot["clims"][layer_key]} - - # matplotlib next_data_plot dict update at each call - data_plot["mode"] = mode - data_plot["frame_range"] = frame_range - data_plot["time_range"] = time_range - data_plot["with_colorbar"] = False - data_plot["recordings"] = recordings - data_plot["layer_keys"] = layer_keys - 
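# The updaters in this backend all follow the same loop: copy `data_plot` once,
# mutate the copy on every widget event, then hand it to the matplotlib plotter
# and flush the canvas. A minimal self-contained sketch of that loop, assuming
# nothing from spikeinterface (`MinimalUpdater` and its dict keys are
# hypothetical names for illustration only):
import numpy as np
import matplotlib.pyplot as plt

class MinimalUpdater:
    """Redraw one axis from a shared dict each time a widget event fires."""

    def __init__(self, data_plot, ax):
        self.ax = ax
        self.next_data_plot = data_plot.copy()  # mutated in place between draws

    def __call__(self, change):
        d = self.next_data_plot
        self.ax.clear()
        self.ax.plot(d["times"], d["gain"] * d["traces"])
        fig = self.ax.get_figure()
        fig.canvas.draw()
        fig.canvas.flush_events()

fig, ax = plt.subplots()
t = np.linspace(0.0, 1.0, 200)
updater = MinimalUpdater({"times": t, "traces": np.sin(40.0 * t), "gain": 1.0}, ax)
updater(None)  # first draw, mirroring the `self.updater(None)` calls above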
data_plot["list_traces"] = list_traces_plot - data_plot["times"] = times - data_plot["clims"] = clims - data_plot["channel_ids"] = channel_ids - - backend_kwargs = {} - backend_kwargs["ax"] = self.ax - self.mpl_plotter.do_plot(data_plot, **backend_kwargs) - - fig = self.ax.figure - fig.canvas.draw() - fig.canvas.flush_events() diff --git a/src/spikeinterface/widgets/ipywidgets/unit_locations.py b/src/spikeinterface/widgets/ipywidgets/unit_locations.py deleted file mode 100644 index e78c0d8fe5..0000000000 --- a/src/spikeinterface/widgets/ipywidgets/unit_locations.py +++ /dev/null @@ -1,91 +0,0 @@ -import numpy as np - -import matplotlib.pyplot as plt -import ipywidgets.widgets as widgets - - -from ..base import to_attr - -from .base_ipywidgets import IpywidgetsPlotter -from .utils import make_unit_controller - -from ..unit_locations import UnitLocationsWidget -from ..matplotlib.unit_locations import UnitLocationsPlotter as MplUnitLocationsPlotter - -from IPython.display import display - - -class UnitLocationsPlotter(IpywidgetsPlotter): - def do_plot(self, data_plot, **backend_kwargs): - cm = 1 / 2.54 - - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - width_cm = backend_kwargs["width_cm"] - height_cm = backend_kwargs["height_cm"] - - ratios = [0.15, 0.85] - - with plt.ioff(): - output = widgets.Output() - with output: - fig, ax = plt.subplots(figsize=((ratios[1] * width_cm) * cm, height_cm * cm)) - plt.show() - - data_plot["unit_ids"] = data_plot["unit_ids"][:1] - unit_widget, unit_controller = make_unit_controller( - data_plot["unit_ids"], list(data_plot["unit_colors"].keys()), ratios[0] * width_cm, height_cm - ) - - self.controller = unit_controller - - mpl_plotter = MplUnitLocationsPlotter() - - self.updater = PlotUpdater(data_plot, mpl_plotter, ax, self.controller) - for w in self.controller.values(): - w.observe(self.updater) - - self.widget = widgets.AppLayout( - center=fig.canvas, - left_sidebar=unit_widget, - pane_widths=ratios + [0], - ) - - # a first update - self.updater(None) - - if backend_kwargs["display"]: - self.check_backend() - display(self.widget) - - -UnitLocationsPlotter.register(UnitLocationsWidget) - - -class PlotUpdater: - def __init__(self, data_plot, mpl_plotter, ax, controller): - self.data_plot = data_plot - self.mpl_plotter = mpl_plotter - self.ax = ax - self.controller = controller - - self.next_data_plot = data_plot.copy() - - def __call__(self, change): - self.ax.clear() - - unit_ids = self.controller["unit_ids"].value - - # matplotlib next_data_plot dict update at each call - data_plot = self.next_data_plot - data_plot["unit_ids"] = unit_ids - data_plot["plot_all_units"] = True - data_plot["plot_legend"] = True - data_plot["hide_axis"] = True - - backend_kwargs = {} - backend_kwargs["ax"] = self.ax - - self.mpl_plotter.do_plot(data_plot, **backend_kwargs) - fig = self.ax.get_figure() - fig.canvas.draw() - fig.canvas.flush_events() diff --git a/src/spikeinterface/widgets/ipywidgets/unit_templates.py b/src/spikeinterface/widgets/ipywidgets/unit_templates.py deleted file mode 100644 index 41da9d8cd3..0000000000 --- a/src/spikeinterface/widgets/ipywidgets/unit_templates.py +++ /dev/null @@ -1,11 +0,0 @@ -from ..unit_templates import UnitTemplatesWidget -from .unit_waveforms import UnitWaveformPlotter - - -class UnitTemplatesPlotter(UnitWaveformPlotter): - def do_plot(self, data_plot, **backend_kwargs): - super().do_plot(data_plot, **backend_kwargs) - self.controller["plot_templates"].layout.visibility = "hidden" - - 
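# Every plotter/widget pair in this patch is wired up with a trailing
# `SomePlotter.register(SomeWidget)` call (one follows immediately). A reduced
# sketch of how such a per-backend registry can work; `SketchBackendPlotter`
# and `SketchWidget` are hypothetical stand-ins, not the actual
# spikeinterface.widgets.base classes:
class SketchBackendPlotter:
    backend = "sketch"

    @classmethod
    def register(cls, widget_cls):
        # attach this plotter class to the widget's per-backend registry
        widget_cls.possible_backends[cls.backend] = cls

    def do_plot(self, data_plot, **backend_kwargs):
        print(f"plotting {sorted(data_plot)} with the '{self.backend}' backend")

class SketchWidget:
    possible_backends = {}  # filled by SketchBackendPlotter.register below

    def plot(self, backend, data_plot, **backend_kwargs):
        plotter = self.possible_backends[backend]()
        return plotter.do_plot(data_plot, **backend_kwargs)

SketchBackendPlotter.register(SketchWidget)
SketchWidget().plot("sketch", {"unit_ids": [0, 1]})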
-UnitTemplatesPlotter.register(UnitTemplatesWidget) diff --git a/src/spikeinterface/widgets/ipywidgets/unit_waveforms.py b/src/spikeinterface/widgets/ipywidgets/unit_waveforms.py deleted file mode 100644 index 012b46038a..0000000000 --- a/src/spikeinterface/widgets/ipywidgets/unit_waveforms.py +++ /dev/null @@ -1,169 +0,0 @@ -import numpy as np - -import matplotlib.pyplot as plt -import ipywidgets.widgets as widgets - - -from ..base import to_attr - -from .base_ipywidgets import IpywidgetsPlotter -from .utils import make_unit_controller - -from ..unit_waveforms import UnitWaveformsWidget -from ..matplotlib.unit_waveforms import UnitWaveformPlotter as MplUnitWaveformPlotter - -from IPython.display import display - - -class UnitWaveformPlotter(IpywidgetsPlotter): - def do_plot(self, data_plot, **backend_kwargs): - cm = 1 / 2.54 - we = data_plot["waveform_extractor"] - - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - width_cm = backend_kwargs["width_cm"] - height_cm = backend_kwargs["height_cm"] - - ratios = [0.1, 0.7, 0.2] - - with plt.ioff(): - output1 = widgets.Output() - with output1: - fig_wf = plt.figure(figsize=((ratios[1] * width_cm) * cm, height_cm * cm)) - plt.show() - output2 = widgets.Output() - with output2: - fig_probe, ax_probe = plt.subplots(figsize=((ratios[2] * width_cm) * cm, height_cm * cm)) - plt.show() - - data_plot["unit_ids"] = data_plot["unit_ids"][:1] - unit_widget, unit_controller = make_unit_controller( - data_plot["unit_ids"], we.unit_ids, ratios[0] * width_cm, height_cm - ) - - same_axis_button = widgets.Checkbox( - value=False, - description="same axis", - disabled=False, - ) - - plot_templates_button = widgets.Checkbox( - value=True, - description="plot templates", - disabled=False, - ) - - hide_axis_button = widgets.Checkbox( - value=True, - description="hide axis", - disabled=False, - ) - - footer = widgets.HBox([same_axis_button, plot_templates_button, hide_axis_button]) - - self.controller = { - "same_axis": same_axis_button, - "plot_templates": plot_templates_button, - "hide_axis": hide_axis_button, - } - self.controller.update(unit_controller) - - mpl_plotter = MplUnitWaveformPlotter() - - self.updater = PlotUpdater(data_plot, mpl_plotter, fig_wf, ax_probe, self.controller) - for w in self.controller.values(): - w.observe(self.updater) - - self.widget = widgets.AppLayout( - center=fig_wf.canvas, - left_sidebar=unit_widget, - right_sidebar=fig_probe.canvas, - pane_widths=ratios, - footer=footer, - ) - - # a first update - self.updater(None) - - if backend_kwargs["display"]: - self.check_backend() - display(self.widget) - - -UnitWaveformPlotter.register(UnitWaveformsWidget) - - -class PlotUpdater: - def __init__(self, data_plot, mpl_plotter, fig_wf, ax_probe, controller): - self.data_plot = data_plot - self.mpl_plotter = mpl_plotter - self.fig_wf = fig_wf - self.ax_probe = ax_probe - self.controller = controller - - self.we = data_plot["waveform_extractor"] - self.next_data_plot = data_plot.copy() - - def __call__(self, change): - self.fig_wf.clear() - self.ax_probe.clear() - - unit_ids = self.controller["unit_ids"].value - same_axis = self.controller["same_axis"].value - plot_templates = self.controller["plot_templates"].value - hide_axis = self.controller["hide_axis"].value - - # matplotlib next_data_plot dict update at each call - data_plot = self.next_data_plot - data_plot["unit_ids"] = unit_ids - data_plot["templates"] = self.we.get_all_templates(unit_ids=unit_ids) - data_plot["template_stds"] = 
self.we.get_all_templates(unit_ids=unit_ids, mode="std") - data_plot["same_axis"] = same_axis - data_plot["plot_templates"] = plot_templates - if data_plot["plot_waveforms"]: - data_plot["wfs_by_ids"] = {unit_id: self.we.get_waveforms(unit_id) for unit_id in unit_ids} - - backend_kwargs = {} - - if same_axis: - backend_kwargs["ax"] = self.fig_wf.add_subplot() - data_plot["set_title"] = False - else: - backend_kwargs["figure"] = self.fig_wf - - self.mpl_plotter.do_plot(data_plot, **backend_kwargs) - if same_axis: - self.mpl_plotter.ax.axis("equal") - if hide_axis: - self.mpl_plotter.ax.axis("off") - else: - if hide_axis: - for i in range(len(unit_ids)): - ax = self.mpl_plotter.axes.flatten()[i] - ax.axis("off") - - # update probe plot - channel_locations = self.we.get_channel_locations() - self.ax_probe.plot( - channel_locations[:, 0], channel_locations[:, 1], ls="", marker="o", color="gray", markersize=2, alpha=0.5 - ) - self.ax_probe.axis("off") - self.ax_probe.axis("equal") - - for unit in unit_ids: - channel_inds = data_plot["sparsity"].unit_id_to_channel_indices[unit] - self.ax_probe.plot( - channel_locations[channel_inds, 0], - channel_locations[channel_inds, 1], - ls="", - marker="o", - markersize=3, - color=self.next_data_plot["unit_colors"][unit], - ) - self.ax_probe.set_xlim(np.min(channel_locations[:, 0]) - 10, np.max(channel_locations[:, 0]) + 10) - fig_probe = self.ax_probe.get_figure() - - self.fig_wf.canvas.draw() - self.fig_wf.canvas.flush_events() - fig_probe.canvas.draw() - fig_probe.canvas.flush_events() diff --git a/src/spikeinterface/widgets/ipywidgets/utils.py b/src/spikeinterface/widgets/ipywidgets/utils.py deleted file mode 100644 index f4b86c3fc2..0000000000 --- a/src/spikeinterface/widgets/ipywidgets/utils.py +++ /dev/null @@ -1,97 +0,0 @@ -import ipywidgets.widgets as widgets -import numpy as np - - -def make_timeseries_controller(t_start, t_stop, layer_keys, num_segments, time_range, mode, all_layers, width_cm): - time_slider = widgets.FloatSlider( - orientation="horizontal", - description="time:", - value=time_range[0], - min=t_start, - max=t_stop, - continuous_update=False, - layout=widgets.Layout(width=f"{width_cm}cm"), - ) - layer_selector = widgets.Dropdown(description="layer", options=layer_keys) - segment_selector = widgets.Dropdown(description="segment", options=list(range(num_segments))) - window_sizer = widgets.BoundedFloatText(value=np.diff(time_range)[0], step=0.1, min=0.005, description="win (s)") - mode_selector = widgets.Dropdown(options=["line", "map"], description="mode", value=mode) - all_layers = widgets.Checkbox(description="plot all layers", value=all_layers) - - controller = { - "layer_key": layer_selector, - "segment_index": segment_selector, - "window": window_sizer, - "t_start": time_slider, - "mode": mode_selector, - "all_layers": all_layers, - } - widget = widgets.VBox( - [time_slider, widgets.HBox([all_layers, layer_selector, segment_selector, window_sizer, mode_selector])] - ) - - return widget, controller - - -def make_unit_controller(unit_ids, all_unit_ids, width_cm, height_cm): - unit_label = widgets.Label(value="units:") - - unit_selector = widgets.SelectMultiple( - options=all_unit_ids, - value=list(unit_ids), - disabled=False, - layout=widgets.Layout(width=f"{width_cm}cm", height=f"{height_cm}cm"), - ) - - controller = {"unit_ids": unit_selector} - widget = widgets.VBox([unit_label, unit_selector]) - - return widget, controller - - -def make_channel_controller(recording, width_cm, height_cm): - channel_label = 
widgets.Label("channel indices:", layout=widgets.Layout(justify_content="center")) - channel_selector = widgets.IntRangeSlider( - value=[0, recording.get_num_channels()], - min=0, - max=recording.get_num_channels(), - step=1, - disabled=False, - continuous_update=False, - orientation="vertical", - readout=True, - readout_format="d", - layout=widgets.Layout(width=f"{0.8 * width_cm}cm", height=f"{height_cm}cm"), - ) - - controller = {"channel_inds": channel_selector} - widget = widgets.VBox([channel_label, channel_selector]) - - return widget, controller - - -def make_scale_controller(width_cm, height_cm): - scale_label = widgets.Label("Scale", layout=widgets.Layout(justify_content="center")) - - plus_selector = widgets.Button( - description="", - disabled=False, - button_style="", # 'success', 'info', 'warning', 'danger' or '' - tooltip="Increase scale", - icon="arrow-up", - layout=widgets.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"), - ) - - minus_selector = widgets.Button( - description="", - disabled=False, - button_style="", # 'success', 'info', 'warning', 'danger' or '' - tooltip="Decrease scale", - icon="arrow-down", - layout=widgets.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"), - ) - - controller = {"plus": plus_selector, "minus": minus_selector} - widget = widgets.VBox([scale_label, plus_selector, minus_selector]) - - return widget, controller diff --git a/src/spikeinterface/widgets/matplotlib/__init__.py b/src/spikeinterface/widgets/matplotlib/__init__.py deleted file mode 100644 index 525396e30d..0000000000 --- a/src/spikeinterface/widgets/matplotlib/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -from .all_amplitudes_distributions import AllAmplitudesDistributionsPlotter -from .amplitudes import AmplitudesPlotter -from .autocorrelograms import AutoCorrelogramsPlotter -from .crosscorrelograms import CrossCorrelogramsPlotter -from .quality_metrics import QualityMetricsPlotter -from .motion import MotionPlotter -from .spike_locations import SpikeLocationsPlotter -from .spikes_on_traces import SpikesOnTracesPlotter -from .template_metrics import TemplateMetricsPlotter -from .template_similarity import TemplateSimilarityPlotter -from .timeseries import TimeseriesPlotter -from .unit_locations import UnitLocationsPlotter -from .unit_templates import UnitTemplatesWidget -from .unit_waveforms import UnitWaveformPlotter -from .unit_waveforms_density_map import UnitWaveformDensityMapPlotter -from .unit_depths import UnitDepthsPlotter -from .unit_summary import UnitSummaryPlotter diff --git a/src/spikeinterface/widgets/matplotlib/all_amplitudes_distributions.py b/src/spikeinterface/widgets/matplotlib/all_amplitudes_distributions.py deleted file mode 100644 index 6985d2167a..0000000000 --- a/src/spikeinterface/widgets/matplotlib/all_amplitudes_distributions.py +++ /dev/null @@ -1,41 +0,0 @@ -import numpy as np - -from ..base import to_attr -from ..all_amplitudes_distributions import AllAmplitudesDistributionsWidget -from .base_mpl import MplPlotter - - -class AllAmplitudesDistributionsPlotter(MplPlotter): - def do_plot(self, data_plot, **backend_kwargs): - dp = to_attr(data_plot) - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - self.make_mpl_figure(**backend_kwargs) - - ax = self.ax - - unit_amps = [] - for i, unit_id in enumerate(dp.unit_ids): - amps = [] - for segment_index in range(dp.num_segments): - amps.append(dp.amplitudes[segment_index][unit_id]) - amps = np.concatenate(amps) - unit_amps.append(amps) - parts = 
ax.violinplot(unit_amps, showmeans=False, showmedians=False, showextrema=False) - - for i, pc in enumerate(parts["bodies"]): - color = dp.unit_colors[dp.unit_ids[i]] - pc.set_facecolor(color) - pc.set_edgecolor("black") - pc.set_alpha(1) - - ax.set_xticks(np.arange(len(dp.unit_ids)) + 1) - ax.set_xticklabels([str(unit_id) for unit_id in dp.unit_ids]) - - ylims = ax.get_ylim() - if np.max(ylims) < 0: - ax.set_ylim(min(ylims), 0) - if np.min(ylims) > 0: - ax.set_ylim(0, max(ylims)) - - -AllAmplitudesDistributionsPlotter.register(AllAmplitudesDistributionsWidget) diff --git a/src/spikeinterface/widgets/matplotlib/amplitudes.py b/src/spikeinterface/widgets/matplotlib/amplitudes.py deleted file mode 100644 index 747709211a..0000000000 --- a/src/spikeinterface/widgets/matplotlib/amplitudes.py +++ /dev/null @@ -1,69 +0,0 @@ -import numpy as np - -from ..base import to_attr -from ..amplitudes import AmplitudesWidget -from .base_mpl import MplPlotter - - -class AmplitudesPlotter(MplPlotter): - def __init__(self) -> None: - self.legend = None - - def do_plot(self, data_plot, **backend_kwargs): - dp = to_attr(data_plot) - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - - if backend_kwargs["axes"] is not None: - axes = backend_kwargs["axes"] - if dp.plot_histograms: - assert np.asarray(axes).size == 2 - else: - assert np.asarray(axes).size == 1 - elif backend_kwargs["ax"] is not None: - assert not dp.plot_histograms - else: - if dp.plot_histograms: - backend_kwargs["num_axes"] = 2 - backend_kwargs["ncols"] = 2 - else: - backend_kwargs["num_axes"] = None - - self.make_mpl_figure(**backend_kwargs) - - scatter_ax = self.axes.flatten()[0] - - for unit_id in dp.unit_ids: - spiketrains = dp.spiketrains[unit_id] - amps = dp.amplitudes[unit_id] - scatter_ax.scatter(spiketrains, amps, color=dp.unit_colors[unit_id], s=3, alpha=1, label=unit_id) - - if dp.plot_histograms: - if dp.bins is None: - bins = int(len(spiketrains) / 30) - else: - bins = dp.bins - ax_hist = self.axes.flatten()[1] - ax_hist.hist(amps, bins=bins, orientation="horizontal", color=dp.unit_colors[unit_id], alpha=0.8) - - if dp.plot_histograms: - ax_hist = self.axes.flatten()[1] - ax_hist.set_ylim(scatter_ax.get_ylim()) - ax_hist.axis("off") - self.figure.tight_layout() - - if dp.plot_legend: - if self.legend is not None: - self.legend.remove() - self.legend = self.figure.legend( - loc="upper center", bbox_to_anchor=(0.5, 1.0), ncol=5, fancybox=True, shadow=True - ) - - scatter_ax.set_xlim(0, dp.total_duration) - scatter_ax.set_xlabel("Times [s]") - scatter_ax.set_ylabel(f"Amplitude") - scatter_ax.spines["top"].set_visible(False) - scatter_ax.spines["right"].set_visible(False) - self.figure.subplots_adjust(bottom=0.1, top=0.9, left=0.1) - - -AmplitudesPlotter.register(AmplitudesWidget) diff --git a/src/spikeinterface/widgets/matplotlib/autocorrelograms.py b/src/spikeinterface/widgets/matplotlib/autocorrelograms.py deleted file mode 100644 index 9245ef6881..0000000000 --- a/src/spikeinterface/widgets/matplotlib/autocorrelograms.py +++ /dev/null @@ -1,30 +0,0 @@ -from ..base import to_attr -from ..autocorrelograms import AutoCorrelogramsWidget -from .base_mpl import MplPlotter - - -class AutoCorrelogramsPlotter(MplPlotter): - def do_plot(self, data_plot, **backend_kwargs): - dp = to_attr(data_plot) - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - backend_kwargs["num_axes"] = len(dp.unit_ids) - - self.make_mpl_figure(**backend_kwargs) - - bins = dp.bins - unit_ids = dp.unit_ids - correlograms = dp.correlograms - 
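# The bar plot below treats `bins` as histogram bin edges: `bins[:-1]` are the
# left edges and `bins[1] - bins[0]` the common width. A standalone sketch of
# the same pattern with synthetic counts (illustration only, not
# spikeinterface API):
import numpy as np
import matplotlib.pyplot as plt

edges = np.arange(-50.0, 51.0, 1.0)                  # bin edges in ms, 1 ms wide
counts = np.random.poisson(20, size=edges.size - 1)  # fake correlogram counts
counts[np.abs(edges[:-1]) < 2.0] = 0                 # mimic a refractory dip

fig, ax = plt.subplots()
ax.bar(x=edges[:-1], height=counts, width=edges[1] - edges[0], align="edge")
ax.set_xlabel("lag (ms)")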
bin_width = bins[1] - bins[0] - - for i, unit_id in enumerate(unit_ids): - ccg = correlograms[i, i] - ax = self.axes.flatten()[i] - if dp.unit_colors is None: - color = "g" - else: - color = dp.unit_colors[unit_id] - ax.bar(x=bins[:-1], height=ccg, width=bin_width, color=color, align="edge") - ax.set_title(str(unit_id)) - - -AutoCorrelogramsPlotter.register(AutoCorrelogramsWidget) diff --git a/src/spikeinterface/widgets/matplotlib/base_mpl.py b/src/spikeinterface/widgets/matplotlib/base_mpl.py deleted file mode 100644 index 266adc8782..0000000000 --- a/src/spikeinterface/widgets/matplotlib/base_mpl.py +++ /dev/null @@ -1,102 +0,0 @@ -from spikeinterface.widgets.base import BackendPlotter - -import matplotlib.pyplot as plt -import numpy as np - - -class MplPlotter(BackendPlotter): - backend = "matplotlib" - backend_kwargs_desc = { - "figure": "Matplotlib figure. When None, it is created. Default None", - "ax": "Single matplotlib axis. When None, it is created. Default None", - "axes": "Multiple matplotlib axes. When None, they is created. Default None", - "ncols": "Number of columns to create in subplots. Default 5", - "figsize": "Size of matplotlib figure. Default None", - "figtitle": "The figure title. Default None", - } - default_backend_kwargs = {"figure": None, "ax": None, "axes": None, "ncols": 5, "figsize": None, "figtitle": None} - - def make_mpl_figure(self, figure=None, ax=None, axes=None, ncols=None, num_axes=None, figsize=None, figtitle=None): - """ - figure/ax/axes : only one of then can be not None - """ - if figure is not None: - assert ax is None and axes is None, "figure/ax/axes : only one of then can be not None" - if num_axes is None: - ax = figure.add_subplot(111) - axes = np.array([[ax]]) - else: - assert ncols is not None - axes = [] - nrows = int(np.ceil(num_axes / ncols)) - axes = np.full((nrows, ncols), fill_value=None, dtype=object) - for i in range(num_axes): - ax = figure.add_subplot(nrows, ncols, i + 1) - r = i // ncols - c = i % ncols - axes[r, c] = ax - elif ax is not None: - assert figure is None and axes is None, "figure/ax/axes : only one of then can be not None" - figure = ax.get_figure() - axes = np.array([[ax]]) - elif axes is not None: - assert figure is None and ax is None, "figure/ax/axes : only one of then can be not None" - axes = np.asarray(axes) - figure = axes.flatten()[0].get_figure() - else: - # 'figure/ax/axes are all None - if num_axes is None: - # one fig with one ax - figure, ax = plt.subplots(figsize=figsize) - axes = np.array([[ax]]) - else: - if num_axes == 0: - # one figure without plots (diffred subplot creation with - figure = plt.figure(figsize=figsize) - ax = None - axes = None - elif num_axes == 1: - figure = plt.figure(figsize=figsize) - ax = figure.add_subplot(111) - axes = np.array([[ax]]) - else: - assert ncols is not None - if num_axes < ncols: - ncols = num_axes - nrows = int(np.ceil(num_axes / ncols)) - figure, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize) - ax = None - # remove extra axes - if ncols * nrows > num_axes: - for i, extra_ax in enumerate(axes.flatten()): - if i >= num_axes: - extra_ax.remove() - r = i // ncols - c = i % ncols - axes[r, c] = None - - self.figure = figure - self.ax = ax - # axes is always a 2D array of ax - self.axes = axes - - if figtitle is not None: - self.figure.suptitle(figtitle) - - -class to_attr(object): - def __init__(self, d): - """ - Helper function that transform a dict into - an object where attributes are the keys of the dict - - d = {'a': 1, 'b': 'yep'} - o = 
to_attr(d) - print(o.a, o.b) - """ - object.__init__(self) - object.__setattr__(self, "__d", d) - - def __getattribute__(self, k): - d = object.__getattribute__(self, "__d") - return d[k] diff --git a/src/spikeinterface/widgets/matplotlib/crosscorrelograms.py b/src/spikeinterface/widgets/matplotlib/crosscorrelograms.py deleted file mode 100644 index 24ecdcdffc..0000000000 --- a/src/spikeinterface/widgets/matplotlib/crosscorrelograms.py +++ /dev/null @@ -1,39 +0,0 @@ -from ..base import to_attr -from ..crosscorrelograms import CrossCorrelogramsWidget -from .base_mpl import MplPlotter - - -class CrossCorrelogramsPlotter(MplPlotter): - def do_plot(self, data_plot, **backend_kwargs): - dp = to_attr(data_plot) - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - backend_kwargs["ncols"] = len(dp.unit_ids) - backend_kwargs["num_axes"] = int(len(dp.unit_ids) ** 2) - - self.make_mpl_figure(**backend_kwargs) - assert self.axes.ndim == 2 - - bins = dp.bins - unit_ids = dp.unit_ids - correlograms = dp.correlograms - bin_width = bins[1] - bins[0] - - for i, unit_id1 in enumerate(unit_ids): - for j, unit_id2 in enumerate(unit_ids): - ccg = correlograms[i, j] - ax = self.axes[i, j] - if i == j: - if dp.unit_colors is None: - color = "g" - else: - color = dp.unit_colors[unit_id1] - else: - color = "k" - ax.bar(x=bins[:-1], height=ccg, width=bin_width, color=color, align="edge") - - for i, unit_id in enumerate(unit_ids): - self.axes[0, i].set_title(str(unit_id)) - self.axes[-1, i].set_xlabel("CCG (ms)") - - -CrossCorrelogramsPlotter.register(CrossCorrelogramsWidget) diff --git a/src/spikeinterface/widgets/matplotlib/metrics.py b/src/spikeinterface/widgets/matplotlib/metrics.py deleted file mode 100644 index cec4c11644..0000000000 --- a/src/spikeinterface/widgets/matplotlib/metrics.py +++ /dev/null @@ -1,50 +0,0 @@ -import numpy as np - -from ..base import to_attr -from .base_mpl import MplPlotter - - -class MetricsPlotter(MplPlotter): - def do_plot(self, data_plot, **backend_kwargs): - dp = to_attr(data_plot) - metrics = dp.metrics - num_metrics = len(metrics.columns) - - if "figsize" not in backend_kwargs: - backend_kwargs["figsize"] = (2 * num_metrics, 2 * num_metrics) - - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - backend_kwargs["num_axes"] = num_metrics**2 - backend_kwargs["ncols"] = num_metrics - - all_unit_ids = metrics.index.values - - self.make_mpl_figure(**backend_kwargs) - assert self.axes.ndim == 2 - - if dp.unit_ids is None: - colors = ["gray"] * len(all_unit_ids) - else: - colors = [] - for unit in all_unit_ids: - color = "gray" if unit not in dp.unit_ids else dp.unit_colors[unit] - colors.append(color) - - self.patches = [] - for i, m1 in enumerate(metrics.columns): - for j, m2 in enumerate(metrics.columns): - if i == j: - self.axes[i, j].hist(metrics[m1], color="gray") - else: - p = self.axes[i, j].scatter(metrics[m1], metrics[m2], c=colors, s=3, marker="o") - self.patches.append(p) - if i == num_metrics - 1: - self.axes[i, j].set_xlabel(m2, fontsize=10) - if j == 0: - self.axes[i, j].set_ylabel(m1, fontsize=10) - self.axes[i, j].set_xticklabels([]) - self.axes[i, j].set_yticklabels([]) - self.axes[i, j].spines["top"].set_visible(False) - self.axes[i, j].spines["right"].set_visible(False) - - self.figure.subplots_adjust(top=0.8, wspace=0.2, hspace=0.2) diff --git a/src/spikeinterface/widgets/matplotlib/motion.py b/src/spikeinterface/widgets/matplotlib/motion.py deleted file mode 100644 index 8a89351c8a..0000000000 --- 
a/src/spikeinterface/widgets/matplotlib/motion.py +++ /dev/null @@ -1,129 +0,0 @@ -from ..base import to_attr -from ..motion import MotionWidget -from .base_mpl import MplPlotter - -import numpy as np -from matplotlib.colors import Normalize - - -class MotionPlotter(MplPlotter): - def do_plot(self, data_plot, **backend_kwargs): - import matplotlib.pyplot as plt - from spikeinterface.sortingcomponents.motion_interpolation import correct_motion_on_peaks - - dp = to_attr(data_plot) - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - - assert backend_kwargs["axes"] is None - assert backend_kwargs["ax"] is None - - self.make_mpl_figure(**backend_kwargs) - fig = self.figure - fig.clear() - - is_rigid = dp.motion.shape[1] == 1 - - gs = fig.add_gridspec(2, 2, wspace=0.3, hspace=0.3) - ax0 = fig.add_subplot(gs[0, 0]) - ax1 = fig.add_subplot(gs[0, 1]) - ax2 = fig.add_subplot(gs[1, 0]) - if not is_rigid: - ax3 = fig.add_subplot(gs[1, 1]) - ax1.sharex(ax0) - ax1.sharey(ax0) - - if dp.motion_lim is None: - motion_lim = np.max(np.abs(dp.motion)) * 1.05 - else: - motion_lim = dp.motion_lim - - if dp.times is None: - temporal_bins_plot = dp.temporal_bins - x = dp.peaks["sample_index"] / dp.sampling_frequency - else: - # use real times and adjust temporal bins with t_start - temporal_bins_plot = dp.temporal_bins + dp.times[0] - x = dp.times[dp.peaks["sample_index"]] - - corrected_location = correct_motion_on_peaks( - dp.peaks, - dp.peak_locations, - dp.sampling_frequency, - dp.motion, - dp.temporal_bins, - dp.spatial_bins, - direction="y", - ) - - y = dp.peak_locations["y"] - y2 = corrected_location["y"] - if dp.scatter_decimate is not None: - x = x[:: dp.scatter_decimate] - y = y[:: dp.scatter_decimate] - y2 = y2[:: dp.scatter_decimate] - - if dp.color_amplitude: - amps = dp.peaks["amplitude"] - amps_abs = np.abs(amps) - q_95 = np.quantile(amps_abs, 0.95) - if dp.scatter_decimate is not None: - amps = amps[:: dp.scatter_decimate] - amps_abs = amps_abs[:: dp.scatter_decimate] - cmap = plt.get_cmap(dp.amplitude_cmap) - if dp.amplitude_clim is None: - amps = amps_abs - amps /= q_95 - c = cmap(amps) - else: - norm_function = Normalize(vmin=dp.amplitude_clim[0], vmax=dp.amplitude_clim[1], clip=True) - c = cmap(norm_function(amps)) - color_kwargs = dict( - color=None, - c=c, - alpha=dp.amplitude_alpha, - ) - else: - color_kwargs = dict(color="k", c=None, alpha=dp.amplitude_alpha) - - ax0.scatter(x, y, s=1, **color_kwargs) - if dp.depth_lim is not None: - ax0.set_ylim(*dp.depth_lim) - ax0.set_title("Peak depth") - ax0.set_xlabel("Times [s]") - ax0.set_ylabel("Depth [um]") - - ax1.scatter(x, y2, s=1, **color_kwargs) - ax1.set_xlabel("Times [s]") - ax1.set_ylabel("Depth [um]") - ax1.set_title("Corrected peak depth") - - ax2.plot(temporal_bins_plot, dp.motion, alpha=0.2, color="black") - ax2.plot(temporal_bins_plot, np.mean(dp.motion, axis=1), color="C0") - ax2.set_ylim(-motion_lim, motion_lim) - ax2.set_ylabel("Motion [um]") - ax2.set_title("Motion vectors") - axes = [ax0, ax1, ax2] - - if not is_rigid: - im = ax3.imshow( - dp.motion.T, - aspect="auto", - origin="lower", - extent=( - temporal_bins_plot[0], - temporal_bins_plot[-1], - dp.spatial_bins[0], - dp.spatial_bins[-1], - ), - ) - im.set_clim(-motion_lim, motion_lim) - cbar = fig.colorbar(im) - cbar.ax.set_xlabel("motion [um]") - ax3.set_xlabel("Times [s]") - ax3.set_ylabel("Depth [um]") - ax3.set_title("Motion vectors") - axes.append(ax3) - self.axes = np.array(axes) - - -MotionPlotter.register(MotionWidget) diff --git 
a/src/spikeinterface/widgets/matplotlib/quality_metrics.py b/src/spikeinterface/widgets/matplotlib/quality_metrics.py deleted file mode 100644 index 3fc368770b..0000000000 --- a/src/spikeinterface/widgets/matplotlib/quality_metrics.py +++ /dev/null @@ -1,9 +0,0 @@ -from ..quality_metrics import QualityMetricsWidget -from .metrics import MetricsPlotter - - -class QualityMetricsPlotter(MetricsPlotter): - pass - - -QualityMetricsPlotter.register(QualityMetricsWidget) diff --git a/src/spikeinterface/widgets/matplotlib/spike_locations.py b/src/spikeinterface/widgets/matplotlib/spike_locations.py deleted file mode 100644 index 5c74df3fc8..0000000000 --- a/src/spikeinterface/widgets/matplotlib/spike_locations.py +++ /dev/null @@ -1,96 +0,0 @@ -from probeinterface import ProbeGroup -from probeinterface.plotting import plot_probe - -import numpy as np - -from ..base import to_attr -from ..spike_locations import SpikeLocationsWidget, estimate_axis_lims -from .base_mpl import MplPlotter - -from matplotlib.patches import Ellipse -from matplotlib.lines import Line2D - - -class SpikeLocationsPlotter(MplPlotter): - def __init__(self) -> None: - self.legend = None - - def do_plot(self, data_plot, **backend_kwargs): - dp = to_attr(data_plot) - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - - self.make_mpl_figure(**backend_kwargs) - - spike_locations = dp.spike_locations - - probegroup = ProbeGroup.from_dict(dp.probegroup_dict) - probe_shape_kwargs = dict(facecolor="w", edgecolor="k", lw=0.5, alpha=1.0) - contacts_kargs = dict(alpha=1.0, edgecolor="k", lw=0.5) - - for probe in probegroup.probes: - text_on_contact = None - if dp.with_channel_ids: - text_on_contact = dp.channel_ids - - poly_contact, poly_contour = plot_probe( - probe, - ax=self.ax, - contacts_colors="w", - contacts_kargs=contacts_kargs, - probe_shape_kwargs=probe_shape_kwargs, - text_on_contact=text_on_contact, - ) - poly_contact.set_zorder(2) - if poly_contour is not None: - poly_contour.set_zorder(1) - - self.ax.set_title("") - - if dp.plot_all_units: - unit_colors = {} - unit_ids = dp.all_unit_ids - for unit in dp.all_unit_ids: - if unit not in dp.unit_ids: - unit_colors[unit] = "gray" - else: - unit_colors[unit] = dp.unit_colors[unit] - else: - unit_ids = dp.unit_ids - unit_colors = dp.unit_colors - labels = dp.unit_ids - - for i, unit in enumerate(unit_ids): - locs = spike_locations[unit] - - zorder = 5 if unit in dp.unit_ids else 3 - self.ax.scatter(locs["x"], locs["y"], s=2, alpha=0.3, color=unit_colors[unit], zorder=zorder) - - handles = [ - Line2D([0], [0], ls="", marker="o", markersize=5, markeredgewidth=2, color=unit_colors[unit]) - for unit in dp.unit_ids - ] - if dp.plot_legend: - if self.legend is not None: - self.legend.remove() - self.legend = self.figure.legend( - handles, labels, loc="upper center", bbox_to_anchor=(0.5, 1.0), ncol=5, fancybox=True, shadow=True - ) - - # set proper axis limits - xlims, ylims = estimate_axis_lims(spike_locations) - - ax_xlims = list(self.ax.get_xlim()) - ax_ylims = list(self.ax.get_ylim()) - - ax_xlims[0] = xlims[0] if xlims[0] < ax_xlims[0] else ax_xlims[0] - ax_xlims[1] = xlims[1] if xlims[1] > ax_xlims[1] else ax_xlims[1] - ax_ylims[0] = ylims[0] if ylims[0] < ax_ylims[0] else ax_ylims[0] - ax_ylims[1] = ylims[1] if ylims[1] > ax_ylims[1] else ax_ylims[1] - - self.ax.set_xlim(ax_xlims) - self.ax.set_ylim(ax_ylims) - if dp.hide_axis: - self.ax.axis("off") - - -SpikeLocationsPlotter.register(SpikeLocationsWidget) diff --git 
a/src/spikeinterface/widgets/matplotlib/spikes_on_traces.py b/src/spikeinterface/widgets/matplotlib/spikes_on_traces.py deleted file mode 100644 index d620c8f28f..0000000000 --- a/src/spikeinterface/widgets/matplotlib/spikes_on_traces.py +++ /dev/null @@ -1,104 +0,0 @@ -import numpy as np - -from ..base import to_attr -from ..spikes_on_traces import SpikesOnTracesWidget -from .base_mpl import MplPlotter -from .timeseries import TimeseriesPlotter - -from matplotlib.patches import Ellipse -from matplotlib.lines import Line2D - - -class SpikesOnTracesPlotter(MplPlotter): - def do_plot(self, data_plot, **backend_kwargs): - dp = to_attr(data_plot) - - # first plot time series - tsplotter = TimeseriesPlotter() - data_plot["timeseries"]["add_legend"] = False - tsplotter.do_plot(dp.timeseries, **backend_kwargs) - self.ax = tsplotter.ax - self.axes = tsplotter.axes - self.figure = tsplotter.figure - - ax = self.ax - - we = dp.waveform_extractor - sorting = dp.waveform_extractor.sorting - frame_range = dp.timeseries["frame_range"] - segment_index = dp.timeseries["segment_index"] - min_y = np.min(dp.timeseries["channel_locations"][:, 1]) - max_y = np.max(dp.timeseries["channel_locations"][:, 1]) - - n = len(dp.timeseries["channel_ids"]) - order = dp.timeseries["order"] - if order is None: - order = np.arange(n) - - if ax.get_legend() is not None: - ax.get_legend().remove() - - # loop through units and plot a scatter of spikes at estimated location - handles = [] - labels = [] - - for unit in dp.unit_ids: - spike_frames = sorting.get_unit_spike_train(unit, segment_index=segment_index) - spike_start, spike_end = np.searchsorted(spike_frames, frame_range) - - chan_ids = dp.sparsity.unit_id_to_channel_ids[unit] - - spike_frames_to_plot = spike_frames[spike_start:spike_end] - - if dp.timeseries["mode"] == "map": - spike_times_to_plot = sorting.get_unit_spike_train( - unit, segment_index=segment_index, return_times=True - )[spike_start:spike_end] - unit_y_loc = min_y + max_y - dp.unit_locations[unit][1] - # markers = np.ones_like(spike_frames_to_plot) * (min_y + max_y - dp.unit_locations[unit][1]) - width = 2 * 1e-3 - ellipse_kwargs = dict(width=width, height=10, fc="none", ec=dp.unit_colors[unit], lw=2) - patches = [Ellipse((s, unit_y_loc), **ellipse_kwargs) for s in spike_times_to_plot] - for p in patches: - ax.add_patch(p) - handles.append( - Line2D( - [0], - [0], - ls="", - marker="o", - markersize=5, - markeredgewidth=2, - markeredgecolor=dp.unit_colors[unit], - markerfacecolor="none", - ) - ) - labels.append(unit) - else: - # construct waveforms - label_set = False - if len(spike_frames_to_plot) > 0: - vspacing = dp.timeseries["vspacing"] - traces = dp.timeseries["list_traces"][0] - waveform_idxs = spike_frames_to_plot[:, None] + np.arange(-we.nbefore, we.nafter) - frame_range[0] - waveform_idxs = np.clip(waveform_idxs, 0, len(dp.timeseries["times"]) - 1) - - times = dp.timeseries["times"][waveform_idxs] - # discontinuity - times[:, -1] = np.nan - times_r = times.reshape(times.shape[0] * times.shape[1]) - waveforms = traces[waveform_idxs] # [:, :, order] - waveforms_r = waveforms.reshape((waveforms.shape[0] * waveforms.shape[1], waveforms.shape[2])) - - for i, chan_id in enumerate(dp.timeseries["channel_ids"]): - offset = vspacing * i - if chan_id in chan_ids: - l = ax.plot(times_r, offset + waveforms_r[:, i], color=dp.unit_colors[unit]) - if not label_set: - handles.append(l[0]) - labels.append(unit) - label_set = True - ax.legend(handles, labels) - - 
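# The class above narrows the spike train to the visible window with a single
# `np.searchsorted` call on the sorted frame indices. That step in isolation,
# with synthetic data (illustration only):
import numpy as np

spike_frames = np.sort(np.random.randint(0, 30_000, size=200))  # sorted train
frame_range = np.array([10_000, 12_000])                        # visible frames

start, stop = np.searchsorted(spike_frames, frame_range)
visible = spike_frames[start:stop]
assert np.all((visible >= frame_range[0]) & (visible < frame_range[1]))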
-SpikesOnTracesPlotter.register(SpikesOnTracesWidget) diff --git a/src/spikeinterface/widgets/matplotlib/template_metrics.py b/src/spikeinterface/widgets/matplotlib/template_metrics.py deleted file mode 100644 index 0aea8ae428..0000000000 --- a/src/spikeinterface/widgets/matplotlib/template_metrics.py +++ /dev/null @@ -1,9 +0,0 @@ -from ..template_metrics import TemplateMetricsWidget -from .metrics import MetricsPlotter - - -class TemplateMetricsPlotter(MetricsPlotter): - pass - - -TemplateMetricsPlotter.register(TemplateMetricsWidget) diff --git a/src/spikeinterface/widgets/matplotlib/template_similarity.py b/src/spikeinterface/widgets/matplotlib/template_similarity.py deleted file mode 100644 index 1e0a2e6fae..0000000000 --- a/src/spikeinterface/widgets/matplotlib/template_similarity.py +++ /dev/null @@ -1,30 +0,0 @@ -import numpy as np - -from ..base import to_attr -from ..template_similarity import TemplateSimilarityWidget -from .base_mpl import MplPlotter - - -class TemplateSimilarityPlotter(MplPlotter): - def do_plot(self, data_plot, **backend_kwargs): - dp = to_attr(data_plot) - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - - self.make_mpl_figure(**backend_kwargs) - - im = self.ax.matshow(dp.similarity, cmap=dp.cmap) - - if dp.show_unit_ticks: - # Major ticks - self.ax.set_xticks(np.arange(0, len(dp.unit_ids))) - self.ax.set_yticks(np.arange(0, len(dp.unit_ids))) - self.ax.xaxis.tick_bottom() - - # Labels for major ticks - self.ax.set_yticklabels(dp.unit_ids, fontsize=12) - self.ax.set_xticklabels(dp.unit_ids, fontsize=12) - if dp.show_colorbar: - self.figure.colorbar(im) - - -TemplateSimilarityPlotter.register(TemplateSimilarityWidget) diff --git a/src/spikeinterface/widgets/matplotlib/timeseries.py b/src/spikeinterface/widgets/matplotlib/timeseries.py deleted file mode 100644 index 0a887b559f..0000000000 --- a/src/spikeinterface/widgets/matplotlib/timeseries.py +++ /dev/null @@ -1,70 +0,0 @@ -import numpy as np - -from ..base import to_attr -from ..timeseries import TimeseriesWidget -from .base_mpl import MplPlotter -from matplotlib.ticker import MaxNLocator - - -class TimeseriesPlotter(MplPlotter): - def do_plot(self, data_plot, **backend_kwargs): - dp = to_attr(data_plot) - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - - self.make_mpl_figure(**backend_kwargs) - ax = self.ax - n = len(dp.channel_ids) - if dp.channel_locations is not None: - y_locs = dp.channel_locations[:, 1] - else: - y_locs = np.arange(n) - min_y = np.min(y_locs) - max_y = np.max(y_locs) - - if dp.mode == "line": - offset = dp.vspacing * (n - 1) - - for layer_key, traces in zip(dp.layer_keys, dp.list_traces): - for i, chan_id in enumerate(dp.channel_ids): - offset = dp.vspacing * i - color = dp.colors[layer_key][chan_id] - ax.plot(dp.times, offset + traces[:, i], color=color) - ax.get_lines()[-1].set_label(layer_key) - - if dp.show_channel_ids: - ax.set_yticks(np.arange(n) * dp.vspacing) - channel_labels = np.array([str(chan_id) for chan_id in dp.channel_ids]) - ax.set_yticklabels(channel_labels) - else: - ax.get_yaxis().set_visible(False) - - ax.set_xlim(*dp.time_range) - ax.set_ylim(-dp.vspacing, dp.vspacing * n) - ax.get_xaxis().set_major_locator(MaxNLocator(prune="both")) - ax.set_xlabel("time (s)") - if dp.add_legend: - ax.legend(loc="upper right") - - elif dp.mode == "map": - assert len(dp.list_traces) == 1, 'plot_timeseries with mode="map" do not support multi recording' - assert len(dp.clims) == 1 - clim = list(dp.clims.values())[0] - extent = (dp.time_range[0], 
dp.time_range[1], min_y, max_y) - im = ax.imshow( - dp.list_traces[0].T, interpolation="nearest", origin="lower", aspect="auto", extent=extent, cmap=dp.cmap - ) - - im.set_clim(*clim) - - if dp.with_colorbar: - self.figure.colorbar(im, ax=ax) - - if dp.show_channel_ids: - ax.set_yticks(np.linspace(min_y, max_y, n) + (max_y - min_y) / n * 0.5) - channel_labels = np.array([str(chan_id) for chan_id in dp.channel_ids]) - ax.set_yticklabels(channel_labels) - else: - ax.get_yaxis().set_visible(False) - - -TimeseriesPlotter.register(TimeseriesWidget) diff --git a/src/spikeinterface/widgets/matplotlib/unit_depths.py b/src/spikeinterface/widgets/matplotlib/unit_depths.py deleted file mode 100644 index aa16ff3578..0000000000 --- a/src/spikeinterface/widgets/matplotlib/unit_depths.py +++ /dev/null @@ -1,22 +0,0 @@ -from ..base import to_attr -from ..unit_depths import UnitDepthsWidget -from .base_mpl import MplPlotter - - -class UnitDepthsPlotter(MplPlotter): - def do_plot(self, data_plot, **backend_kwargs): - dp = to_attr(data_plot) - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - self.make_mpl_figure(**backend_kwargs) - - ax = self.ax - size = dp.num_spikes / max(dp.num_spikes) * 120 - ax.scatter(dp.unit_amplitudes, dp.unit_depths, color=dp.colors, s=size) - - ax.set_aspect(3) - ax.set_xlabel("amplitude") - ax.set_ylabel("depth [um]") - ax.set_xlim(0, max(dp.unit_amplitudes) * 1.2) - - -UnitDepthsPlotter.register(UnitDepthsWidget) diff --git a/src/spikeinterface/widgets/matplotlib/unit_locations.py b/src/spikeinterface/widgets/matplotlib/unit_locations.py deleted file mode 100644 index 6f084c0aec..0000000000 --- a/src/spikeinterface/widgets/matplotlib/unit_locations.py +++ /dev/null @@ -1,95 +0,0 @@ -from probeinterface import ProbeGroup -from probeinterface.plotting import plot_probe - -import numpy as np -from spikeinterface.core import waveform_extractor - -from ..base import to_attr -from ..unit_locations import UnitLocationsWidget -from .base_mpl import MplPlotter - -from matplotlib.patches import Ellipse -from matplotlib.lines import Line2D - - -class UnitLocationsPlotter(MplPlotter): - def __init__(self) -> None: - self.legend = None - - def do_plot(self, data_plot, **backend_kwargs): - dp = to_attr(data_plot) - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - - self.make_mpl_figure(**backend_kwargs) - - unit_locations = dp.unit_locations - - probegroup = ProbeGroup.from_dict(dp.probegroup_dict) - probe_shape_kwargs = dict(facecolor="w", edgecolor="k", lw=0.5, alpha=1.0) - contacts_kargs = dict(alpha=1.0, edgecolor="k", lw=0.5) - - for probe in probegroup.probes: - text_on_contact = None - if dp.with_channel_ids: - text_on_contact = dp.channel_ids - - poly_contact, poly_contour = plot_probe( - probe, - ax=self.ax, - contacts_colors="w", - contacts_kargs=contacts_kargs, - probe_shape_kwargs=probe_shape_kwargs, - text_on_contact=text_on_contact, - ) - poly_contact.set_zorder(2) - if poly_contour is not None: - poly_contour.set_zorder(1) - - self.ax.set_title("") - - # color = np.array([dp.unit_colors[unit_id] for unit_id in dp.unit_ids]) - width = height = 10 - ellipse_kwargs = dict(width=width, height=height, lw=2) - - if dp.plot_all_units: - unit_colors = {} - unit_ids = dp.all_unit_ids - for unit in dp.all_unit_ids: - if unit not in dp.unit_ids: - unit_colors[unit] = "gray" - else: - unit_colors[unit] = dp.unit_colors[unit] - else: - unit_ids = dp.unit_ids - unit_colors = dp.unit_colors - labels = dp.unit_ids - - patches = [ - Ellipse( - 
(unit_locations[unit]), - color=unit_colors[unit], - zorder=5 if unit in dp.unit_ids else 3, - alpha=0.9 if unit in dp.unit_ids else 0.5, - **ellipse_kwargs, - ) - for i, unit in enumerate(unit_ids) - ] - for p in patches: - self.ax.add_patch(p) - handles = [ - Line2D([0], [0], ls="", marker="o", markersize=5, markeredgewidth=2, color=unit_colors[unit]) - for unit in dp.unit_ids - ] - - if dp.plot_legend: - if self.legend is not None: - self.legend.remove() - self.legend = self.figure.legend( - handles, labels, loc="upper center", bbox_to_anchor=(0.5, 1.0), ncol=5, fancybox=True, shadow=True - ) - - if dp.hide_axis: - self.ax.axis("off") - - -UnitLocationsPlotter.register(UnitLocationsWidget) diff --git a/src/spikeinterface/widgets/matplotlib/unit_summary.py b/src/spikeinterface/widgets/matplotlib/unit_summary.py deleted file mode 100644 index 5327afa25e..0000000000 --- a/src/spikeinterface/widgets/matplotlib/unit_summary.py +++ /dev/null @@ -1,73 +0,0 @@ -import numpy as np - -from ..base import to_attr -from ..unit_summary import UnitSummaryWidget -from .base_mpl import MplPlotter - - -from .unit_locations import UnitLocationsPlotter -from .amplitudes import AmplitudesPlotter -from .unit_waveforms import UnitWaveformPlotter -from .unit_waveforms_density_map import UnitWaveformDensityMapPlotter - -from .autocorrelograms import AutoCorrelogramsPlotter - - -class UnitSummaryPlotter(MplPlotter): - def do_plot(self, data_plot, **backend_kwargs): - dp = to_attr(data_plot) - - # force the figure without axes - if "figsize" not in backend_kwargs: - backend_kwargs["figsize"] = (18, 7) - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - backend_kwargs["num_axes"] = 0 - backend_kwargs["ax"] = None - backend_kwargs["axes"] = None - - self.make_mpl_figure(**backend_kwargs) - - # and use custum grid spec - fig = self.figure - nrows = 2 - ncols = 3 - if dp.plot_data_acc is not None or dp.plot_data_amplitudes is not None: - ncols += 1 - if dp.plot_data_amplitudes is not None: - nrows += 1 - gs = fig.add_gridspec(nrows, ncols) - - if dp.plot_data_unit_locations is not None: - ax1 = fig.add_subplot(gs[:2, 0]) - UnitLocationsPlotter().do_plot(dp.plot_data_unit_locations, ax=ax1) - x, y = dp.unit_location[0], dp.unit_location[1] - ax1.set_xlim(x - 80, x + 80) - ax1.set_ylim(y - 250, y + 250) - ax1.set_xticks([]) - ax1.set_xlabel(None) - ax1.set_ylabel(None) - - ax2 = fig.add_subplot(gs[:2, 1]) - UnitWaveformPlotter().do_plot(dp.plot_data_waveforms, ax=ax2) - ax2.set_title(None) - - ax3 = fig.add_subplot(gs[:2, 2]) - UnitWaveformDensityMapPlotter().do_plot(dp.plot_data_waveform_density, ax=ax3) - ax3.set_ylabel(None) - - if dp.plot_data_acc is not None: - ax4 = fig.add_subplot(gs[:2, 3]) - AutoCorrelogramsPlotter().do_plot(dp.plot_data_acc, ax=ax4) - ax4.set_title(None) - ax4.set_yticks([]) - - if dp.plot_data_amplitudes is not None: - ax5 = fig.add_subplot(gs[2, :3]) - ax6 = fig.add_subplot(gs[2, 3]) - axes = np.array([ax5, ax6]) - AmplitudesPlotter().do_plot(dp.plot_data_amplitudes, axes=axes) - - fig.suptitle(f"unit_id: {dp.unit_id}") - - -UnitSummaryPlotter.register(UnitSummaryWidget) diff --git a/src/spikeinterface/widgets/matplotlib/unit_templates.py b/src/spikeinterface/widgets/matplotlib/unit_templates.py deleted file mode 100644 index c1ce085bf2..0000000000 --- a/src/spikeinterface/widgets/matplotlib/unit_templates.py +++ /dev/null @@ -1,9 +0,0 @@ -from ..unit_templates import UnitTemplatesWidget -from .unit_waveforms import UnitWaveformPlotter - - -class 
UnitTemplatesPlotter(UnitWaveformPlotter): - pass - - -UnitTemplatesPlotter.register(UnitTemplatesWidget) diff --git a/src/spikeinterface/widgets/matplotlib/unit_waveforms.py b/src/spikeinterface/widgets/matplotlib/unit_waveforms.py deleted file mode 100644 index f499954918..0000000000 --- a/src/spikeinterface/widgets/matplotlib/unit_waveforms.py +++ /dev/null @@ -1,95 +0,0 @@ -import numpy as np - -from ..base import to_attr -from ..unit_waveforms import UnitWaveformsWidget -from .base_mpl import MplPlotter - - -class UnitWaveformPlotter(MplPlotter): - def __init__(self) -> None: - self.legend = None - - def do_plot(self, data_plot, **backend_kwargs): - dp = to_attr(data_plot) - - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - - if backend_kwargs["axes"] is not None: - assert len(backend_kwargs["axes"]) >= len(dp.unit_ids), "Provide as many 'axes' as neurons" - elif backend_kwargs["ax"] is not None: - assert dp.same_axis, "If 'same_axis' is not used, provide as many 'axes' as neurons" - else: - if dp.same_axis: - backend_kwargs["num_axes"] = 1 - backend_kwargs["ncols"] = None - else: - backend_kwargs["num_axes"] = len(dp.unit_ids) - backend_kwargs["ncols"] = min(dp.ncols, len(dp.unit_ids)) - - self.make_mpl_figure(**backend_kwargs) - - for i, unit_id in enumerate(dp.unit_ids): - if dp.same_axis: - ax = self.ax - else: - ax = self.axes.flatten()[i] - color = dp.unit_colors[unit_id] - - chan_inds = dp.sparsity.unit_id_to_channel_indices[unit_id] - xvectors_flat = dp.xvectors[:, chan_inds].T.flatten() - - # plot waveforms - if dp.plot_waveforms: - wfs = dp.wfs_by_ids[unit_id] - if dp.unit_selected_waveforms is not None: - wfs = wfs[dp.unit_selected_waveforms[unit_id]] - elif dp.max_spikes_per_unit is not None: - if len(wfs) > dp.max_spikes_per_unit: - random_idxs = np.random.permutation(len(wfs))[: dp.max_spikes_per_unit] - wfs = wfs[random_idxs] - wfs = wfs * dp.y_scale + dp.y_offset[None, :, chan_inds] - wfs_flat = wfs.swapaxes(1, 2).reshape(wfs.shape[0], -1).T - - if dp.x_offset_units: - # 0.7 is to match spacing in xvect - xvec = xvectors_flat + i * 0.7 * dp.delta_x - else: - xvec = xvectors_flat - - ax.plot(xvec, wfs_flat, lw=dp.lw_waveforms, alpha=dp.alpha_waveforms, color=color) - - if not dp.plot_templates: - ax.get_lines()[-1].set_label(f"{unit_id}") - - # plot template - if dp.plot_templates: - template = dp.templates[i, :, :][:, chan_inds] * dp.y_scale + dp.y_offset[:, chan_inds] - - if dp.x_offset_units: - # 0.7 is to match spacing in xvect - xvec = xvectors_flat + i * 0.7 * dp.delta_x - else: - xvec = xvectors_flat - - ax.plot( - xvec, template.T.flatten(), lw=dp.lw_templates, alpha=dp.alpha_templates, color=color, label=unit_id - ) - - template_label = dp.unit_ids[i] - if dp.set_title: - ax.set_title(f"template {template_label}") - - # plot channels - if dp.plot_channels: - # TODO enhance this - ax.scatter(dp.channel_locations[:, 0], dp.channel_locations[:, 1], color="k") - - if dp.same_axis and dp.plot_legend: - if self.legend is not None: - self.legend.remove() - self.legend = self.figure.legend( - loc="upper center", bbox_to_anchor=(0.5, 1.0), ncol=5, fancybox=True, shadow=True - ) - - -UnitWaveformPlotter.register(UnitWaveformsWidget) diff --git a/src/spikeinterface/widgets/matplotlib/unit_waveforms_density_map.py b/src/spikeinterface/widgets/matplotlib/unit_waveforms_density_map.py deleted file mode 100644 index ff9c1ec91b..0000000000 --- a/src/spikeinterface/widgets/matplotlib/unit_waveforms_density_map.py +++ /dev/null @@ -1,77 +0,0 @@ -import numpy as 
np - -from ..base import to_attr -from ..unit_waveforms_density_map import UnitWaveformDensityMapWidget -from .base_mpl import MplPlotter - - -class UnitWaveformDensityMapPlotter(MplPlotter): - def do_plot(self, data_plot, **backend_kwargs): - dp = to_attr(data_plot) - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - - if backend_kwargs["axes"] is not None or backend_kwargs["ax"] is not None: - self.make_mpl_figure(**backend_kwargs) - else: - if dp.same_axis: - num_axes = 1 - else: - num_axes = len(dp.unit_ids) - backend_kwargs["ncols"] = 1 - backend_kwargs["num_axes"] = num_axes - self.make_mpl_figure(**backend_kwargs) - - if dp.same_axis: - ax = self.ax - hist2d = dp.all_hist2d - im = ax.imshow( - hist2d.T, - interpolation="nearest", - origin="lower", - aspect="auto", - extent=(0, hist2d.shape[0], dp.bin_min, dp.bin_max), - cmap="hot", - ) - else: - for unit_index, unit_id in enumerate(dp.unit_ids): - hist2d = dp.all_hist2d[unit_id] - ax = self.axes.flatten()[unit_index] - im = ax.imshow( - hist2d.T, - interpolation="nearest", - origin="lower", - aspect="auto", - extent=(0, hist2d.shape[0], dp.bin_min, dp.bin_max), - cmap="hot", - ) - - for unit_index, unit_id in enumerate(dp.unit_ids): - if dp.same_axis: - ax = self.ax - else: - ax = self.axes.flatten()[unit_index] - color = dp.unit_colors[unit_id] - ax.plot(dp.templates_flat[unit_id], color=color, lw=1) - - # final cosmetics - for unit_index, unit_id in enumerate(dp.unit_ids): - if dp.same_axis: - ax = self.ax - if unit_index != 0: - continue - else: - ax = self.axes.flatten()[unit_index] - chan_inds = dp.channel_inds[unit_id] - for i, chan_ind in enumerate(chan_inds): - if i != 0: - ax.axvline(i * dp.template_width, color="w", lw=3) - channel_id = dp.channel_ids[chan_ind] - x = i * dp.template_width + dp.template_width // 2 - y = (dp.bin_max + dp.bin_min) / 2.0 - ax.text(x, y, f"chan_id {channel_id}", color="w", ha="center", va="center") - - ax.set_xticks([]) - ax.set_ylabel(f"unit_id {unit_id}") - - -UnitWaveformDensityMapPlotter.register(UnitWaveformDensityMapWidget) diff --git a/src/spikeinterface/widgets/sortingview/__init__.py b/src/spikeinterface/widgets/sortingview/__init__.py deleted file mode 100644 index 5663f95078..0000000000 --- a/src/spikeinterface/widgets/sortingview/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -from .amplitudes import AmplitudesPlotter -from .autocorrelograms import AutoCorrelogramsPlotter -from .crosscorrelograms import CrossCorrelogramsPlotter -from .quality_metrics import QualityMetricsPlotter -from .sorting_summary import SortingSummaryPlotter -from .spike_locations import SortingviewPlotter -from .template_metrics import TemplateMetricsPlotter -from .template_similarity import TemplateSimilarityPlotter -from .timeseries import TimeseriesPlotter -from .unit_locations import UnitLocationsPlotter -from .unit_templates import UnitTemplatesPlotter diff --git a/src/spikeinterface/widgets/sortingview/amplitudes.py b/src/spikeinterface/widgets/sortingview/amplitudes.py deleted file mode 100644 index 8676ccd994..0000000000 --- a/src/spikeinterface/widgets/sortingview/amplitudes.py +++ /dev/null @@ -1,36 +0,0 @@ -import numpy as np - -from ..base import to_attr -from ..amplitudes import AmplitudesWidget -from .base_sortingview import SortingviewPlotter - - -class AmplitudesPlotter(SortingviewPlotter): - default_label = "SpikeInterface - Amplitudes" - - def do_plot(self, data_plot, **backend_kwargs): - import sortingview.views as vv - - backend_kwargs = 
self.update_backend_kwargs(**backend_kwargs) - dp = to_attr(data_plot) - - unit_ids = self.make_serializable(dp.unit_ids) - - sa_items = [ - vv.SpikeAmplitudesItem( - unit_id=u, - spike_times_sec=dp.spiketrains[u].astype("float32"), - spike_amplitudes=dp.amplitudes[u].astype("float32"), - ) - for u in unit_ids - ] - - v_spike_amplitudes = vv.SpikeAmplitudes( - start_time_sec=0, end_time_sec=dp.total_duration, plots=sa_items, hide_unit_selector=dp.hide_unit_selector - ) - - self.handle_display_and_url(v_spike_amplitudes, **backend_kwargs) - return v_spike_amplitudes - - -AmplitudesPlotter.register(AmplitudesWidget) diff --git a/src/spikeinterface/widgets/sortingview/autocorrelograms.py b/src/spikeinterface/widgets/sortingview/autocorrelograms.py deleted file mode 100644 index 345f8c2bdf..0000000000 --- a/src/spikeinterface/widgets/sortingview/autocorrelograms.py +++ /dev/null @@ -1,34 +0,0 @@ -from ..base import to_attr -from ..autocorrelograms import AutoCorrelogramsWidget -from .base_sortingview import SortingviewPlotter - - -class AutoCorrelogramsPlotter(SortingviewPlotter): - default_label = "SpikeInterface - Auto Correlograms" - - def do_plot(self, data_plot, **backend_kwargs): - import sortingview.views as vv - - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - dp = to_attr(data_plot) - unit_ids = self.make_serializable(dp.unit_ids) - - ac_items = [] - for i in range(len(unit_ids)): - for j in range(i, len(unit_ids)): - if i == j: - ac_items.append( - vv.AutocorrelogramItem( - unit_id=unit_ids[i], - bin_edges_sec=(dp.bins / 1000.0).astype("float32"), - bin_counts=dp.correlograms[i, j].astype("int32"), - ) - ) - - v_autocorrelograms = vv.Autocorrelograms(autocorrelograms=ac_items) - - self.handle_display_and_url(v_autocorrelograms, **backend_kwargs) - return v_autocorrelograms - - -AutoCorrelogramsPlotter.register(AutoCorrelogramsWidget) diff --git a/src/spikeinterface/widgets/sortingview/base_sortingview.py b/src/spikeinterface/widgets/sortingview/base_sortingview.py deleted file mode 100644 index c42da0fba3..0000000000 --- a/src/spikeinterface/widgets/sortingview/base_sortingview.py +++ /dev/null @@ -1,103 +0,0 @@ -import numpy as np - -from ...core.core_tools import check_json -from spikeinterface.widgets.base import BackendPlotter - - -class SortingviewPlotter(BackendPlotter): - backend = "sortingview" - backend_kwargs_desc = { - "generate_url": "If True, the figurl URL is generated and printed. Default True", - "display": "If True and in jupyter notebook/lab, the widget is displayed in the cell. Default True.", - "figlabel": "The figurl figure label. Default None", - "height": "The height of the sortingview View in jupyter. Default None", - } - default_backend_kwargs = {"generate_url": True, "display": True, "figlabel": None, "height": None} - - def __init__(self): - self.view = None - self.url = None - - def make_serializable(*args): - dict_to_serialize = {int(i): a for i, a in enumerate(args[1:])} - serializable_dict = check_json(dict_to_serialize) - returns = () - for i in range(len(args) - 1): - returns += (serializable_dict[str(i)],) - if len(returns) == 1: - returns = returns[0] - return returns - - @staticmethod - def is_notebook() -> bool: - try: - shell = get_ipython().__class__.__name__ - if shell == "ZMQInteractiveShell": - return True # Jupyter notebook or qtconsole - elif shell == "TerminalInteractiveShell": - return False # Terminal running IPython - else: - return False # Other type (?) 
- except NameError: - return False - - def handle_display_and_url(self, view, **backend_kwargs): - self.set_view(view) - if self.is_notebook() and backend_kwargs["display"]: - display(self.view.jupyter(height=backend_kwargs["height"])) - if backend_kwargs["generate_url"]: - figlabel = backend_kwargs.get("figlabel") - if figlabel is None: - figlabel = self.default_label - url = view.url(label=figlabel) - self.set_url(url) - print(url) - - # make view and url accessible by the plotter - def set_view(self, view): - self.view = view - - def set_url(self, url): - self.url = url - - -def generate_unit_table_view(sorting, unit_properties=None, similarity_scores=None): - import sortingview.views as vv - - if unit_properties is None: - ut_columns = [] - ut_rows = [vv.UnitsTableRow(unit_id=u, values={}) for u in sorting.unit_ids] - else: - ut_columns = [] - ut_rows = [] - values = {} - valid_unit_properties = [] - for prop_name in unit_properties: - property_values = sorting.get_property(prop_name) - # make dtype available - val0 = np.array(property_values[0]) - if val0.dtype.kind in ("i", "u"): - dtype = "int" - elif val0.dtype.kind in ("U", "S"): - dtype = "str" - elif val0.dtype.kind == "f": - dtype = "float" - elif val0.dtype.kind == "b": - dtype = "bool" - else: - print(f"Unsupported dtype {val0.dtype} for property {prop_name}. Skipping") - continue - ut_columns.append(vv.UnitsTableColumn(key=prop_name, label=prop_name, dtype=dtype)) - valid_unit_properties.append(prop_name) - - for ui, unit in enumerate(sorting.unit_ids): - for prop_name in valid_unit_properties: - property_values = sorting.get_property(prop_name) - val0 = property_values[0] - if np.isnan(property_values[ui]): - continue - values[prop_name] = property_values[ui] - ut_rows.append(vv.UnitsTableRow(unit_id=unit, values=check_json(values))) - - v_units_table = vv.UnitsTable(rows=ut_rows, columns=ut_columns, similarity_scores=similarity_scores) - return v_units_table diff --git a/src/spikeinterface/widgets/sortingview/crosscorrelograms.py b/src/spikeinterface/widgets/sortingview/crosscorrelograms.py deleted file mode 100644 index ec9c7bb16c..0000000000 --- a/src/spikeinterface/widgets/sortingview/crosscorrelograms.py +++ /dev/null @@ -1,37 +0,0 @@ -from ..base import to_attr -from ..crosscorrelograms import CrossCorrelogramsWidget -from .base_sortingview import SortingviewPlotter - - -class CrossCorrelogramsPlotter(SortingviewPlotter): - default_label = "SpikeInterface - Cross Correlograms" - - def do_plot(self, data_plot, **backend_kwargs): - import sortingview.views as vv - - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - dp = to_attr(data_plot) - - unit_ids = self.make_serializable(dp.unit_ids) - - cc_items = [] - for i in range(len(unit_ids)): - for j in range(i, len(unit_ids)): - cc_items.append( - vv.CrossCorrelogramItem( - unit_id1=unit_ids[i], - unit_id2=unit_ids[j], - bin_edges_sec=(dp.bins / 1000.0).astype("float32"), - bin_counts=dp.correlograms[i, j].astype("int32"), - ) - ) - - v_cross_correlograms = vv.CrossCorrelograms( - cross_correlograms=cc_items, hide_unit_selector=dp.hide_unit_selector - ) - - self.handle_display_and_url(v_cross_correlograms, **backend_kwargs) - return v_cross_correlograms - - -CrossCorrelogramsPlotter.register(CrossCorrelogramsWidget) diff --git a/src/spikeinterface/widgets/sortingview/metrics.py b/src/spikeinterface/widgets/sortingview/metrics.py deleted file mode 100644 index d46256739e..0000000000 --- a/src/spikeinterface/widgets/sortingview/metrics.py +++ /dev/null @@ -1,61 
+0,0 @@ -import numpy as np - -from ...core.core_tools import check_json -from ..base import to_attr -from .base_sortingview import SortingviewPlotter, generate_unit_table_view - - -class MetricsPlotter(SortingviewPlotter): - def do_plot(self, data_plot, **backend_kwargs): - import sortingview.views as vv - - dp = to_attr(data_plot) - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - - metrics = dp.metrics - metric_names = list(metrics.columns) - - if dp.unit_ids is None: - unit_ids = metrics.index.values - else: - unit_ids = dp.unit_ids - unit_ids = self.make_serializable(unit_ids) - - metrics_sv = [] - for col in metric_names: - dtype = metrics.iloc[0][col].dtype - metric = vv.UnitMetricsGraphMetric(key=col, label=col, dtype=dtype.str) - metrics_sv.append(metric) - - units_m = [] - for unit_id in unit_ids: - values = check_json(metrics.loc[unit_id].to_dict()) - values_skip_nans = {} - for k, v in values.items(): - if np.isnan(v): - continue - values_skip_nans[k] = v - - units_m.append(vv.UnitMetricsGraphUnit(unit_id=unit_id, values=values_skip_nans)) - v_metrics = vv.UnitMetricsGraph(units=units_m, metrics=metrics_sv) - - if not dp.hide_unit_selector: - if dp.include_metrics_data: - # make a view of the sorting to add tmp properties - sorting_copy = dp.sorting.select_units(unit_ids=dp.sorting.unit_ids) - for col in metric_names: - if col not in sorting_copy.get_property_keys(): - sorting_copy.set_property(col, metrics[col].values) - # generate table with properties - v_units_table = generate_unit_table_view(sorting_copy, unit_properties=metric_names) - else: - v_units_table = generate_unit_table_view(dp.sorting) - - view = vv.Splitter( - direction="horizontal", item1=vv.LayoutItem(v_units_table), item2=vv.LayoutItem(v_metrics) - ) - else: - view = v_metrics - - self.handle_display_and_url(view, **backend_kwargs) - return view diff --git a/src/spikeinterface/widgets/sortingview/quality_metrics.py b/src/spikeinterface/widgets/sortingview/quality_metrics.py deleted file mode 100644 index 379ba158a5..0000000000 --- a/src/spikeinterface/widgets/sortingview/quality_metrics.py +++ /dev/null @@ -1,11 +0,0 @@ -from .metrics import MetricsPlotter -from ..quality_metrics import QualityMetricsWidget - - -class QualityMetricsPlotter(MetricsPlotter): - default_label = "SpikeInterface - Quality Metrics" - - pass - - -QualityMetricsPlotter.register(QualityMetricsWidget) diff --git a/src/spikeinterface/widgets/sortingview/sorting_summary.py b/src/spikeinterface/widgets/sortingview/sorting_summary.py deleted file mode 100644 index bb248e1691..0000000000 --- a/src/spikeinterface/widgets/sortingview/sorting_summary.py +++ /dev/null @@ -1,86 +0,0 @@ -from ..base import to_attr -from ..sorting_summary import SortingSummaryWidget -from .base_sortingview import SortingviewPlotter, generate_unit_table_view - -from .amplitudes import AmplitudesPlotter -from .autocorrelograms import AutoCorrelogramsPlotter -from .crosscorrelograms import CrossCorrelogramsPlotter -from .template_similarity import TemplateSimilarityPlotter -from .unit_locations import UnitLocationsPlotter -from .unit_templates import UnitTemplatesPlotter - - -class SortingSummaryPlotter(SortingviewPlotter): - default_label = "SpikeInterface - Sorting Summary" - - def do_plot(self, data_plot, **backend_kwargs): - import sortingview.views as vv - - dp = to_attr(data_plot) - - unit_ids = self.make_serializable(dp.unit_ids) - - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - - amplitudes_plotter = AmplitudesPlotter() - 
v_spike_amplitudes = amplitudes_plotter.do_plot( - dp.amplitudes, generate_url=False, display=False, backend="sortingview" - ) - template_plotter = UnitTemplatesPlotter() - v_average_waveforms = template_plotter.do_plot( - dp.templates, generate_url=False, display=False, backend="sortingview" - ) - xcorrelograms_plotter = CrossCorrelogramsPlotter() - v_cross_correlograms = xcorrelograms_plotter.do_plot( - dp.correlograms, generate_url=False, display=False, backend="sortingview" - ) - unitlocation_plotter = UnitLocationsPlotter() - v_unit_locations = unitlocation_plotter.do_plot( - dp.unit_locations, generate_url=False, display=False, backend="sortingview" - ) - # similarity - similarity_scores = [] - for i1, u1 in enumerate(unit_ids): - for i2, u2 in enumerate(unit_ids): - similarity_scores.append( - vv.UnitSimilarityScore( - unit_id1=u1, unit_id2=u2, similarity=dp.similarity["similarity"][i1, i2].astype("float32") - ) - ) - - # unit ids - v_units_table = generate_unit_table_view( - dp.waveform_extractor.sorting, dp.unit_table_properties, similarity_scores=similarity_scores - ) - - if dp.curation: - v_curation = vv.SortingCuration2(label_choices=dp.label_choices) - v1 = vv.Splitter(direction="vertical", item1=vv.LayoutItem(v_units_table), item2=vv.LayoutItem(v_curation)) - else: - v1 = v_units_table - v2 = vv.Splitter( - direction="horizontal", - item1=vv.LayoutItem(v_unit_locations, stretch=0.2), - item2=vv.LayoutItem( - vv.Splitter( - direction="horizontal", - item1=vv.LayoutItem(v_average_waveforms), - item2=vv.LayoutItem( - vv.Splitter( - direction="vertical", - item1=vv.LayoutItem(v_spike_amplitudes), - item2=vv.LayoutItem(v_cross_correlograms), - ) - ), - ) - ), - ) - - # assemble layout - v_summary = vv.Splitter(direction="horizontal", item1=vv.LayoutItem(v1), item2=vv.LayoutItem(v2)) - - self.handle_display_and_url(v_summary, **backend_kwargs) - return v_summary - - -SortingSummaryPlotter.register(SortingSummaryWidget) diff --git a/src/spikeinterface/widgets/sortingview/spike_locations.py b/src/spikeinterface/widgets/sortingview/spike_locations.py deleted file mode 100644 index 747c3df4e7..0000000000 --- a/src/spikeinterface/widgets/sortingview/spike_locations.py +++ /dev/null @@ -1,64 +0,0 @@ -from ..base import to_attr -from ..spike_locations import SpikeLocationsWidget, estimate_axis_lims -from .base_sortingview import SortingviewPlotter, generate_unit_table_view - - -class SpikeLocationsPlotter(SortingviewPlotter): - default_label = "SpikeInterface - Spike Locations" - - def do_plot(self, data_plot, **backend_kwargs): - import sortingview.views as vv - - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - dp = to_attr(data_plot) - spike_locations = dp.spike_locations - - # ensure serializable for sortingview - unit_ids, channel_ids = self.make_serializable(dp.unit_ids, dp.channel_ids) - - locations = {str(ch): dp.channel_locations[i_ch].astype("float32") for i_ch, ch in enumerate(channel_ids)} - xlims, ylims = estimate_axis_lims(spike_locations) - - unit_items = [] - for unit in unit_ids: - spike_times_sec = dp.sorting.get_unit_spike_train( - unit_id=unit, segment_index=dp.segment_index, return_times=True - ) - unit_items.append( - vv.SpikeLocationsItem( - unit_id=unit, - spike_times_sec=spike_times_sec.astype("float32"), - x_locations=spike_locations[unit]["x"].astype("float32"), - y_locations=spike_locations[unit]["y"].astype("float32"), - ) - ) - - v_spike_locations = vv.SpikeLocations( - units=unit_items, - hide_unit_selector=dp.hide_unit_selector, - 
x_range=xlims.astype("float32"), - y_range=ylims.astype("float32"), - channel_locations=locations, - disable_auto_rotate=True, - ) - - if not dp.hide_unit_selector: - v_units_table = generate_unit_table_view(dp.sorting) - - view = vv.Box( - direction="horizontal", - items=[ - vv.LayoutItem(v_units_table, max_size=150), - vv.LayoutItem(v_spike_locations), - ], - ) - else: - view = v_spike_locations - - self.set_view(view) - - self.handle_display_and_url(view, **backend_kwargs) - return view - - -SpikeLocationsPlotter.register(SpikeLocationsWidget) diff --git a/src/spikeinterface/widgets/sortingview/template_metrics.py b/src/spikeinterface/widgets/sortingview/template_metrics.py deleted file mode 100644 index 204bb8f377..0000000000 --- a/src/spikeinterface/widgets/sortingview/template_metrics.py +++ /dev/null @@ -1,11 +0,0 @@ -from .metrics import MetricsPlotter -from ..template_metrics import TemplateMetricsWidget - - -class TemplateMetricsPlotter(MetricsPlotter): - default_label = "SpikeInterface - Template Metrics" - - pass - - -TemplateMetricsPlotter.register(TemplateMetricsWidget) diff --git a/src/spikeinterface/widgets/sortingview/template_similarity.py b/src/spikeinterface/widgets/sortingview/template_similarity.py deleted file mode 100644 index e35b8c2e34..0000000000 --- a/src/spikeinterface/widgets/sortingview/template_similarity.py +++ /dev/null @@ -1,32 +0,0 @@ -from ..base import to_attr -from ..template_similarity import TemplateSimilarityWidget -from .base_sortingview import SortingviewPlotter - - -class TemplateSimilarityPlotter(SortingviewPlotter): - default_label = "SpikeInterface - Template Similarity" - - def do_plot(self, data_plot, **backend_kwargs): - import sortingview.views as vv - - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - dp = to_attr(data_plot) - - # ensure serializable for sortingview - unit_ids = self.make_serializable(dp.unit_ids) - - # similarity - ss_items = [] - for i1, u1 in enumerate(unit_ids): - for i2, u2 in enumerate(unit_ids): - ss_items.append( - vv.UnitSimilarityScore(unit_id1=u1, unit_id2=u2, similarity=dp.similarity[i1, i2].astype("float32")) - ) - - view = vv.UnitSimilarityMatrix(unit_ids=list(unit_ids), similarity_scores=ss_items) - - self.handle_display_and_url(view, **backend_kwargs) - return view - - -TemplateSimilarityPlotter.register(TemplateSimilarityWidget) diff --git a/src/spikeinterface/widgets/sortingview/timeseries.py b/src/spikeinterface/widgets/sortingview/timeseries.py deleted file mode 100644 index eec0e920e4..0000000000 --- a/src/spikeinterface/widgets/sortingview/timeseries.py +++ /dev/null @@ -1,54 +0,0 @@ -import numpy as np -import warnings - -from ..base import to_attr -from ..timeseries import TimeseriesWidget -from ..utils import array_to_image -from .base_sortingview import SortingviewPlotter - - -class TimeseriesPlotter(SortingviewPlotter): - default_label = "SpikeInterface - Timeseries" - - def do_plot(self, data_plot, **backend_kwargs): - import sortingview.views as vv - - try: - import pyvips - except ImportError: - raise ImportError("To use the timeseries in sorting view you need the pyvips package.") - - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - dp = to_attr(data_plot) - - assert dp.mode == "map", 'sortingview plot_timeseries is only mode="map"' - - if not dp.order_channel_by_depth: - warnings.warn( - "It is recommended to set 'order_channel_by_depth' to True " "when using the sortingview backend" - ) - - tiled_layers = [] - for layer_key, traces in zip(dp.layer_keys, 
dp.list_traces): - img = array_to_image( - traces, - clim=dp.clims[layer_key], - num_timepoints_per_row=dp.num_timepoints_per_row, - colormap=dp.cmap, - scalebar=True, - sampling_frequency=dp.recordings[layer_key].get_sampling_frequency(), - ) - - tiled_layers.append(vv.TiledImageLayer(layer_key, img)) - - view_ts = vv.TiledImage(tile_size=dp.tile_size, layers=tiled_layers) - - self.set_view(view_ts) - - # timeseries currently doesn't display on the jupyter backend - backend_kwargs["display"] = False - self.handle_display_and_url(view_ts, **backend_kwargs) - return view_ts - - -TimeseriesPlotter.register(TimeseriesWidget) diff --git a/src/spikeinterface/widgets/sortingview/unit_locations.py b/src/spikeinterface/widgets/sortingview/unit_locations.py deleted file mode 100644 index 368b45321f..0000000000 --- a/src/spikeinterface/widgets/sortingview/unit_locations.py +++ /dev/null @@ -1,44 +0,0 @@ -from ..base import to_attr -from ..unit_locations import UnitLocationsWidget -from .base_sortingview import SortingviewPlotter, generate_unit_table_view - - -class UnitLocationsPlotter(SortingviewPlotter): - default_label = "SpikeInterface - Unit Locations" - - def do_plot(self, data_plot, **backend_kwargs): - import sortingview.views as vv - - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - dp = to_attr(data_plot) - - # ensure serializable for sortingview - unit_ids, channel_ids = self.make_serializable(dp.unit_ids, dp.channel_ids) - - locations = {str(ch): dp.channel_locations[i_ch].astype("float32") for i_ch, ch in enumerate(channel_ids)} - - unit_items = [] - for unit_id in unit_ids: - unit_items.append( - vv.UnitLocationsItem( - unit_id=unit_id, x=float(dp.unit_locations[unit_id][0]), y=float(dp.unit_locations[unit_id][1]) - ) - ) - - v_unit_locations = vv.UnitLocations(units=unit_items, channel_locations=locations, disable_auto_rotate=True) - - if not dp.hide_unit_selector: - v_units_table = generate_unit_table_view(dp.sorting) - - view = vv.Box( - direction="horizontal", - items=[vv.LayoutItem(v_units_table, max_size=150), vv.LayoutItem(v_unit_locations)], - ) - else: - view = v_unit_locations - - self.handle_display_and_url(view, **backend_kwargs) - return view - - -UnitLocationsPlotter.register(UnitLocationsWidget) diff --git a/src/spikeinterface/widgets/sortingview/unit_templates.py b/src/spikeinterface/widgets/sortingview/unit_templates.py deleted file mode 100644 index 37595740fd..0000000000 --- a/src/spikeinterface/widgets/sortingview/unit_templates.py +++ /dev/null @@ -1,54 +0,0 @@ -from ..base import to_attr -from ..unit_templates import UnitTemplatesWidget -from .base_sortingview import SortingviewPlotter, generate_unit_table_view - - -class UnitTemplatesPlotter(SortingviewPlotter): - default_label = "SpikeInterface - Unit Templates" - - def do_plot(self, data_plot, **backend_kwargs): - import sortingview.views as vv - - dp = to_attr(data_plot) - backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - - # ensure serializable for sortingview - unit_id_to_channel_ids = dp.sparsity.unit_id_to_channel_ids - unit_id_to_channel_indices = dp.sparsity.unit_id_to_channel_indices - - unit_ids, channel_ids = self.make_serializable(dp.unit_ids, dp.channel_ids) - - templates_dict = {} - for u_i, unit in enumerate(unit_ids): - templates_dict[unit] = {} - templates_dict[unit]["mean"] = dp.templates[u_i].T.astype("float32")[unit_id_to_channel_indices[unit]] - templates_dict[unit]["std"] = dp.template_stds[u_i].T.astype("float32")[unit_id_to_channel_indices[unit]] - - aw_items 
= [
-            vv.AverageWaveformItem(
-                unit_id=u,
-                channel_ids=list(unit_id_to_channel_ids[u]),
-                waveform=t["mean"].astype("float32"),
-                waveform_std_dev=t["std"].astype("float32"),
-            )
-            for u, t in templates_dict.items()
-        ]
-
-        locations = {str(ch): dp.channel_locations[i_ch].astype("float32") for i_ch, ch in enumerate(channel_ids)}
-        v_average_waveforms = vv.AverageWaveforms(average_waveforms=aw_items, channel_locations=locations)
-
-        if not dp.hide_unit_selector:
-            v_units_table = generate_unit_table_view(dp.waveform_extractor.sorting)
-
-            view = vv.Box(
-                direction="horizontal",
-                items=[vv.LayoutItem(v_units_table, max_size=150), vv.LayoutItem(v_average_waveforms)],
-            )
-        else:
-            view = v_average_waveforms
-
-        self.handle_display_and_url(view, **backend_kwargs)
-        return view
-
-
-UnitTemplatesPlotter.register(UnitTemplatesWidget)

From 2aba94fe9cc544c467bf4790517205ec4f9e6b8a Mon Sep 17 00:00:00 2001
From: Heberto Mayorquin
Date: Thu, 20 Jul 2023 10:48:47 +0200
Subject: [PATCH 121/166] Eliminate restore keys in CI and simplify
 installation of dev version dependencies (#1858)

* eliminate restore keys and simplify installation of dev version

* forgot to activate the virtual env

* revert restoring keys for gin data
---
 .github/actions/build-test-environment/action.yml | 14 ++++++--------
 .github/is_spikeinterface_dev.py                  |  6 ------
 .github/workflows/full-test-with-codecov.yml      |  5 +----
 .github/workflows/full-test.yml                   |  5 +----
 4 files changed, 8 insertions(+), 22 deletions(-)
 delete mode 100644 .github/is_spikeinterface_dev.py

diff --git a/.github/actions/build-test-environment/action.yml b/.github/actions/build-test-environment/action.yml
index 29ab453a99..004fe31203 100644
--- a/.github/actions/build-test-environment/action.yml
+++ b/.github/actions/build-test-environment/action.yml
@@ -24,18 +24,16 @@ runs:
         pip install -e .[test,extractors,full]
       shell: bash
    - name: Force installation of latest dev from key-packages when running dev (not release)
-      id: version
      run: |
        source ${{ github.workspace }}/test_env/bin/activate
-        if python ./.github/is_spikeinterface_dev.py; then
+        spikeinterface_is_dev_version=$(python -c "import importlib.metadata; version = importlib.metadata.version('spikeinterface'); print(version.endswith('dev0'))")
+        if [ $spikeinterface_is_dev_version = "True" ]; then
          echo "Running spikeinterface dev version"
-          pip uninstall -y neo
-          pip uninstall -y probeinterface
-          pip install git+https://github.com/NeuralEnsemble/python-neo
-          pip install git+https://github.com/SpikeInterface/probeinterface
-        else
-          echo "Running tests for release"
+          pip install --no-cache-dir git+https://github.com/NeuralEnsemble/python-neo
+          pip install --no-cache-dir git+https://github.com/SpikeInterface/probeinterface
+        else
+          echo "Running tests for release, using pyproject.toml versions of neo and probeinterface"
        fi
      shell: bash
    - name: git-annex install
      run: |
diff --git a/.github/is_spikeinterface_dev.py b/.github/is_spikeinterface_dev.py
deleted file mode 100644
index 621305af90..0000000000
--- a/.github/is_spikeinterface_dev.py
+++ /dev/null
@@ -1,6 +0,0 @@
-import importlib.metadata
-
-package_name = "spikeinterface"
-version = importlib.metadata.version(package_name)
-if version.endswith("dev0"):
-    print(True)
diff --git a/.github/workflows/full-test-with-codecov.yml b/.github/workflows/full-test-with-codecov.yml
index 3da889d64e..a5561c2ffc 100644
--- a/.github/workflows/full-test-with-codecov.yml
+++ b/.github/workflows/full-test-with-codecov.yml
@@ -32,8 +32,6 @@ jobs:
      with:
        path: ${{
github.workspace }}/test_env key: ${{ runner.os }}-venv-${{ hashFiles('**/pyproject.toml') }}-${{ steps.date.outputs.date }} - restore-keys: | - ${{ runner.os }}-venv- - name: Get ephy_testing_data current head hash # the key depends on the last comit repo https://gin.g-node.org/NeuralEnsemble/ephy_testing_data.git id: vars @@ -48,8 +46,7 @@ jobs: with: path: ~/spikeinterface_datasets key: ${{ runner.os }}-datasets-${{ steps.vars.outputs.HASH_EPHY_DATASET }} - restore-keys: | - ${{ runner.os }}-datasets + restore-keys: ${{ runner.os }}-datasets - name: Install packages uses: ./.github/actions/build-test-environment - name: Shows installed packages by pip, git-annex and cached testing files diff --git a/.github/workflows/full-test.yml b/.github/workflows/full-test.yml index 633b226e57..ac5130bade 100644 --- a/.github/workflows/full-test.yml +++ b/.github/workflows/full-test.yml @@ -37,8 +37,6 @@ jobs: with: path: ${{ github.workspace }}/test_env key: ${{ runner.os }}-venv-${{ hashFiles('**/pyproject.toml') }}-${{ steps.date.outputs.date }} - restore-keys: | - ${{ runner.os }}-venv- - name: Get ephy_testing_data current head hash # the key depends on the last comit repo https://gin.g-node.org/NeuralEnsemble/ephy_testing_data.git id: vars @@ -53,8 +51,7 @@ jobs: with: path: ~/spikeinterface_datasets key: ${{ runner.os }}-datasets-${{ steps.vars.outputs.HASH_EPHY_DATASET }} - restore-keys: | - ${{ runner.os }}-datasets + restore-keys: ${{ runner.os }}-datasets - name: Install packages uses: ./.github/actions/build-test-environment - name: Shows installed packages by pip, git-annex and cached testing files From d2d5a9cdc016845c11dbdaa50e2a7e39a4275a62 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 20 Jul 2023 10:55:24 +0200 Subject: [PATCH 122/166] some clean --- src/spikeinterface/widgets/__init__.py | 34 ------ .../widgets/all_amplitudes_distributions.py | 2 +- src/spikeinterface/widgets/amplitudes.py | 6 +- .../widgets/autocorrelograms.py | 4 +- src/spikeinterface/widgets/base.py | 58 +-------- .../widgets/crosscorrelograms.py | 4 +- src/spikeinterface/widgets/metrics.py | 6 +- src/spikeinterface/widgets/motion.py | 2 +- src/spikeinterface/widgets/sorting_summary.py | 2 +- src/spikeinterface/widgets/spike_locations.py | 6 +- .../widgets/spikes_on_traces.py | 4 +- .../widgets/template_similarity.py | 4 +- src/spikeinterface/widgets/timeseries.py | 6 +- src/spikeinterface/widgets/unit_depths.py | 2 +- src/spikeinterface/widgets/unit_locations.py | 6 +- src/spikeinterface/widgets/unit_summary.py | 2 +- src/spikeinterface/widgets/unit_templates.py | 2 +- src/spikeinterface/widgets/unit_waveforms.py | 4 +- .../widgets/unit_waveforms_density_map.py | 2 +- ...pywidgets_utils.py => utils_ipywidgets.py} | 0 ...atplotlib_utils.py => utils_matplotlib.py} | 0 ...tingview_utils.py => utils_sortingview.py} | 0 src/spikeinterface/widgets/widget_list.py | 113 ++++-------------- 23 files changed, 64 insertions(+), 205 deletions(-) rename src/spikeinterface/widgets/{ipywidgets_utils.py => utils_ipywidgets.py} (100%) rename src/spikeinterface/widgets/{matplotlib_utils.py => utils_matplotlib.py} (100%) rename src/spikeinterface/widgets/{sortingview_utils.py => utils_sortingview.py} (100%) diff --git a/src/spikeinterface/widgets/__init__.py b/src/spikeinterface/widgets/__init__.py index bb779ff7fb..d3066f51fa 100644 --- a/src/spikeinterface/widgets/__init__.py +++ b/src/spikeinterface/widgets/__init__.py @@ -1,37 +1,3 @@ -# check if backend are available -# try: -# import matplotlib - -# HAVE_MPL = True 
-# except: -# HAVE_MPL = False - -# try: -# import sortingview - -# HAVE_SV = True -# except: -# HAVE_SV = False - -# try: -# import ipywidgets - -# HAVE_IPYW = True -# except: -# HAVE_IPYW = False - - -# # theses import make the Widget.resgister() at import time -# if HAVE_MPL: -# import spikeinterface.widgets.matplotlib - -# if HAVE_SV: -# import spikeinterface.widgets.sortingview - -# if HAVE_IPYW: -# import spikeinterface.widgets.ipywidgets - -# when importing widget list backend are already registered from .widget_list import * # general functions diff --git a/src/spikeinterface/widgets/all_amplitudes_distributions.py b/src/spikeinterface/widgets/all_amplitudes_distributions.py index d3cca278c9..56aaa77804 100644 --- a/src/spikeinterface/widgets/all_amplitudes_distributions.py +++ b/src/spikeinterface/widgets/all_amplitudes_distributions.py @@ -50,7 +50,7 @@ def __init__( def plot_matplotlib(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt - from .matplotlib_utils import make_mpl_figure + from .utils_matplotlib import make_mpl_figure from matplotlib.patches import Ellipse from matplotlib.lines import Line2D diff --git a/src/spikeinterface/widgets/amplitudes.py b/src/spikeinterface/widgets/amplitudes.py index a2a3ccff3b..2be71f7470 100644 --- a/src/spikeinterface/widgets/amplitudes.py +++ b/src/spikeinterface/widgets/amplitudes.py @@ -115,7 +115,7 @@ def __init__( def plot_matplotlib(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt - from .matplotlib_utils import make_mpl_figure + from .utils_matplotlib import make_mpl_figure from probeinterface.plotting import plot_probe from matplotlib.patches import Ellipse @@ -182,7 +182,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt import ipywidgets.widgets as widgets from IPython.display import display - from .ipywidgets_utils import check_ipywidget_backend, make_unit_controller + from .utils_ipywidgets import check_ipywidget_backend, make_unit_controller check_ipywidget_backend() @@ -269,7 +269,7 @@ def _update_ipywidget(self, change): def plot_sortingview(self, data_plot, **backend_kwargs): import sortingview.views as vv - from .sortingview_utils import generate_unit_table_view, make_serializable, handle_display_and_url + from .utils_sortingview import generate_unit_table_view, make_serializable, handle_display_and_url # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) dp = to_attr(data_plot) diff --git a/src/spikeinterface/widgets/autocorrelograms.py b/src/spikeinterface/widgets/autocorrelograms.py index e7b5014367..ecb015bee2 100644 --- a/src/spikeinterface/widgets/autocorrelograms.py +++ b/src/spikeinterface/widgets/autocorrelograms.py @@ -11,7 +11,7 @@ def __init__(self, *args, **kargs): def plot_matplotlib(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt - from .matplotlib_utils import make_mpl_figure + from .utils_matplotlib import make_mpl_figure dp = to_attr(data_plot) # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) @@ -37,7 +37,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): def plot_sortingview(self, data_plot, **backend_kwargs): import sortingview.views as vv - from .sortingview_utils import make_serializable, handle_display_and_url + from .utils_sortingview import make_serializable, handle_display_and_url # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) dp = to_attr(data_plot) diff --git a/src/spikeinterface/widgets/base.py b/src/spikeinterface/widgets/base.py index 
7c52e1f993..eaa151ccd9 100644 --- a/src/spikeinterface/widgets/base.py +++ b/src/spikeinterface/widgets/base.py @@ -49,9 +49,6 @@ def set_default_plotter_backend(backend): class BaseWidget: - # this need to be reset in the subclass - possible_backends = None - def __init__( self, data_plot=None, @@ -79,6 +76,12 @@ def __init__( if immediate_plot: self.do_plot() + # subclass must define one method per supported backend: + # def plot_matplotlib(self, data_plot, **backend_kwargs): + # def plot_ipywidgets(self, data_plot, **backend_kwargs): + # def plot_sortingview(self, data_plot, **backend_kwargs): + + @classmethod def get_possible_backends(cls): return [k for k in default_backend_kwargs if hasattr(cls, f"plot_{k}")] @@ -91,25 +94,10 @@ def check_backend(self, backend): ) return backend - # def check_backend_kwargs(self, plotter, backend, **backend_kwargs): - # plotter_kwargs = plotter.default_backend_kwargs - # for k in backend_kwargs: - # if k not in plotter_kwargs: - # raise Exception( - # f"{k} is not a valid plot argument or backend keyword argument. " - # f"Possible backend keyword arguments for {backend} are: {list(plotter_kwargs.keys())}" - # ) - def do_plot(self): - # backend = self.check_backend(backend) - func = getattr(self, f"plot_{self.backend}") func(self.data_plot, **self.backend_kwargs) - # @classmethod - # def register_backend(cls, backend_plotter): - # cls.possible_backends[backend_plotter.backend] = backend_plotter - @staticmethod def check_extensions(waveform_extractor, extensions): if isinstance(extensions, str): @@ -127,27 +115,6 @@ def check_extensions(waveform_extractor, extensions): raise Exception(error_msg) -# class BackendPlotter: -# backend = "" - -# @classmethod -# def register(cls, widget_cls): -# widget_cls.register_backend(cls) - -# def update_backend_kwargs(self, **backend_kwargs): -# backend_kwargs_ = self.default_backend_kwargs.copy() -# backend_kwargs_.update(backend_kwargs) -# return backend_kwargs_ - - -# def copy_signature(source_fct): -# def copy(target_fct): -# target_fct.__signature__ = inspect.signature(source_fct) -# return target_fct - -# return copy - - class to_attr(object): def __init__(self, d): """ @@ -164,16 +131,3 @@ def __init__(self, d): def __getattribute__(self, k): d = object.__getattribute__(self, "__d") return d[k] - - -# def define_widget_function_from_class(widget_class, name): -# @copy_signature(widget_class) -# def widget_func(*args, **kwargs): -# W = widget_class(*args, **kwargs) -# W.do_plot(W.backend, **W.backend_kwargs) -# return W.plotter - -# widget_func.__doc__ = widget_class.__doc__ -# widget_func.__name__ = name - -# return widget_func diff --git a/src/spikeinterface/widgets/crosscorrelograms.py b/src/spikeinterface/widgets/crosscorrelograms.py index 4b83e61b69..5635466a2d 100644 --- a/src/spikeinterface/widgets/crosscorrelograms.py +++ b/src/spikeinterface/widgets/crosscorrelograms.py @@ -68,7 +68,7 @@ def __init__( def plot_matplotlib(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt - from .matplotlib_utils import make_mpl_figure + from .utils_matplotlib import make_mpl_figure dp = to_attr(data_plot) # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) @@ -104,7 +104,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): def plot_sortingview(self, data_plot, **backend_kwargs): import sortingview.views as vv - from .sortingview_utils import generate_unit_table_view, make_serializable, handle_display_and_url + from .utils_sortingview import generate_unit_table_view, 
make_serializable, handle_display_and_url # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) dp = to_attr(data_plot) diff --git a/src/spikeinterface/widgets/metrics.py b/src/spikeinterface/widgets/metrics.py index 6551bb067e..3d5e247b93 100644 --- a/src/spikeinterface/widgets/metrics.py +++ b/src/spikeinterface/widgets/metrics.py @@ -81,7 +81,7 @@ def __init__( def plot_matplotlib(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt - from .matplotlib_utils import make_mpl_figure + from .utils_matplotlib import make_mpl_figure dp = to_attr(data_plot) metrics = dp.metrics @@ -132,7 +132,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt import ipywidgets.widgets as widgets from IPython.display import display - from .ipywidgets_utils import check_ipywidget_backend, make_unit_controller + from .utils_ipywidgets import check_ipywidget_backend, make_unit_controller check_ipywidget_backend() @@ -228,7 +228,7 @@ def _update_ipywidget(self, change): def plot_sortingview(self, data_plot, **backend_kwargs): import sortingview.views as vv - from .sortingview_utils import generate_unit_table_view, make_serializable, handle_display_and_url + from .utils_sortingview import generate_unit_table_view, make_serializable, handle_display_and_url dp = to_attr(data_plot) # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py index 1ebbb71743..6420fe8848 100644 --- a/src/spikeinterface/widgets/motion.py +++ b/src/spikeinterface/widgets/motion.py @@ -71,7 +71,7 @@ def __init__( def plot_matplotlib(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt - from .matplotlib_utils import make_mpl_figure + from .utils_matplotlib import make_mpl_figure from matplotlib.colors import Normalize from spikeinterface.sortingcomponents.motion_interpolation import correct_motion_on_peaks diff --git a/src/spikeinterface/widgets/sorting_summary.py b/src/spikeinterface/widgets/sorting_summary.py index 5498df9a33..9b4279d94e 100644 --- a/src/spikeinterface/widgets/sorting_summary.py +++ b/src/spikeinterface/widgets/sorting_summary.py @@ -85,7 +85,7 @@ def __init__( def plot_sortingview(self, data_plot, **backend_kwargs): import sortingview.views as vv - from .sortingview_utils import generate_unit_table_view, make_serializable, handle_display_and_url + from .utils_sortingview import generate_unit_table_view, make_serializable, handle_display_and_url dp = to_attr(data_plot) we = dp.waveform_extractor diff --git a/src/spikeinterface/widgets/spike_locations.py b/src/spikeinterface/widgets/spike_locations.py index 06495409cf..62feff9372 100644 --- a/src/spikeinterface/widgets/spike_locations.py +++ b/src/spikeinterface/widgets/spike_locations.py @@ -107,7 +107,7 @@ def __init__( def plot_matplotlib(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt - from .matplotlib_utils import make_mpl_figure + from .utils_matplotlib import make_mpl_figure from matplotlib.lines import Line2D from probeinterface import ProbeGroup @@ -195,7 +195,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt import ipywidgets.widgets as widgets from IPython.display import display - from .ipywidgets_utils import check_ipywidget_backend, make_unit_controller + from .utils_ipywidgets import check_ipywidget_backend, make_unit_controller check_ipywidget_backend() @@ -272,7 +272,7 @@ def _update_ipywidget(self, change): def 
plot_sortingview(self, data_plot, **backend_kwargs): import sortingview.views as vv - from .sortingview_utils import generate_unit_table_view, make_serializable, handle_display_and_url + from .utils_sortingview import generate_unit_table_view, make_serializable, handle_display_and_url # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) dp = to_attr(data_plot) diff --git a/src/spikeinterface/widgets/spikes_on_traces.py b/src/spikeinterface/widgets/spikes_on_traces.py index 0aeb923f38..74fc7f7501 100644 --- a/src/spikeinterface/widgets/spikes_on_traces.py +++ b/src/spikeinterface/widgets/spikes_on_traces.py @@ -163,7 +163,7 @@ def __init__( def plot_matplotlib(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt - from .matplotlib_utils import make_mpl_figure + from .utils_matplotlib import make_mpl_figure from matplotlib.patches import Ellipse from matplotlib.lines import Line2D @@ -286,7 +286,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt import ipywidgets.widgets as widgets from IPython.display import display - from .ipywidgets_utils import check_ipywidget_backend, make_unit_controller + from .utils_ipywidgets import check_ipywidget_backend, make_unit_controller check_ipywidget_backend() diff --git a/src/spikeinterface/widgets/template_similarity.py b/src/spikeinterface/widgets/template_similarity.py index a6e0356db1..f43a47db62 100644 --- a/src/spikeinterface/widgets/template_similarity.py +++ b/src/spikeinterface/widgets/template_similarity.py @@ -65,7 +65,7 @@ def __init__( def plot_matplotlib(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt - from .matplotlib_utils import make_mpl_figure + from .utils_matplotlib import make_mpl_figure dp = to_attr(data_plot) # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) @@ -89,7 +89,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): def plot_sortingview(self, data_plot, **backend_kwargs): import sortingview.views as vv - from .sortingview_utils import generate_unit_table_view, make_serializable, handle_display_and_url + from .utils_sortingview import generate_unit_table_view, make_serializable, handle_display_and_url # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) dp = to_attr(data_plot) diff --git a/src/spikeinterface/widgets/timeseries.py b/src/spikeinterface/widgets/timeseries.py index 86e886babc..7165dec12a 100644 --- a/src/spikeinterface/widgets/timeseries.py +++ b/src/spikeinterface/widgets/timeseries.py @@ -218,7 +218,7 @@ def __init__( def plot_matplotlib(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt from matplotlib.ticker import MaxNLocator - from .matplotlib_utils import make_mpl_figure + from .utils_matplotlib import make_mpl_figure dp = to_attr(data_plot) # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) @@ -284,7 +284,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt import ipywidgets.widgets as widgets from IPython.display import display - from .ipywidgets_utils import ( + from .utils_ipywidgets import ( check_ipywidget_backend, make_timeseries_controller, make_channel_controller, @@ -506,7 +506,7 @@ def _update_ipywidget(self, change): def plot_sortingview(self, data_plot, **backend_kwargs): import sortingview.views as vv - from .sortingview_utils import generate_unit_table_view, make_serializable, handle_display_and_url + from .utils_sortingview import generate_unit_table_view, make_serializable, handle_display_and_url try: 
import pyvips diff --git a/src/spikeinterface/widgets/unit_depths.py b/src/spikeinterface/widgets/unit_depths.py index faf9198c0d..9bcafb53e4 100644 --- a/src/spikeinterface/widgets/unit_depths.py +++ b/src/spikeinterface/widgets/unit_depths.py @@ -59,7 +59,7 @@ def __init__( def plot_matplotlib(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt - from .matplotlib_utils import make_mpl_figure + from .utils_matplotlib import make_mpl_figure dp = to_attr(data_plot) # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) diff --git a/src/spikeinterface/widgets/unit_locations.py b/src/spikeinterface/widgets/unit_locations.py index 9e35f7b32c..b923374a07 100644 --- a/src/spikeinterface/widgets/unit_locations.py +++ b/src/spikeinterface/widgets/unit_locations.py @@ -84,7 +84,7 @@ def __init__( def plot_matplotlib(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt - from .matplotlib_utils import make_mpl_figure + from .utils_matplotlib import make_mpl_figure from probeinterface.plotting import plot_probe from matplotlib.patches import Ellipse @@ -170,7 +170,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt import ipywidgets.widgets as widgets from IPython.display import display - from .ipywidgets_utils import check_ipywidget_backend, make_unit_controller + from .utils_ipywidgets import check_ipywidget_backend, make_unit_controller check_ipywidget_backend() @@ -242,7 +242,7 @@ def _update_ipywidget(self, change): def plot_sortingview(self, data_plot, **backend_kwargs): import sortingview.views as vv - from .sortingview_utils import generate_unit_table_view, make_serializable, handle_display_and_url + from .utils_sortingview import generate_unit_table_view, make_serializable, handle_display_and_url # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) dp = to_attr(data_plot) diff --git a/src/spikeinterface/widgets/unit_summary.py b/src/spikeinterface/widgets/unit_summary.py index 66f522e3ca..82e3e79fb9 100644 --- a/src/spikeinterface/widgets/unit_summary.py +++ b/src/spikeinterface/widgets/unit_summary.py @@ -106,7 +106,7 @@ def __init__( def plot_matplotlib(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt - from .matplotlib_utils import make_mpl_figure + from .utils_matplotlib import make_mpl_figure dp = to_attr(data_plot) diff --git a/src/spikeinterface/widgets/unit_templates.py b/src/spikeinterface/widgets/unit_templates.py index 04b26e300f..7e9a1c21a8 100644 --- a/src/spikeinterface/widgets/unit_templates.py +++ b/src/spikeinterface/widgets/unit_templates.py @@ -11,7 +11,7 @@ def __init__(self, *args, **kargs): def plot_sortingview(self, data_plot, **backend_kwargs): import sortingview.views as vv - from .sortingview_utils import generate_unit_table_view, make_serializable, handle_display_and_url + from .utils_sortingview import generate_unit_table_view, make_serializable, handle_display_and_url dp = to_attr(data_plot) # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) diff --git a/src/spikeinterface/widgets/unit_waveforms.py b/src/spikeinterface/widgets/unit_waveforms.py index 833f13881d..f82d276d92 100644 --- a/src/spikeinterface/widgets/unit_waveforms.py +++ b/src/spikeinterface/widgets/unit_waveforms.py @@ -167,7 +167,7 @@ def __init__( def plot_matplotlib(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt - from .matplotlib_utils import make_mpl_figure + from .utils_matplotlib import make_mpl_figure from probeinterface.plotting import plot_probe 
from matplotlib.patches import Ellipse @@ -260,7 +260,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt import ipywidgets.widgets as widgets from IPython.display import display - from .ipywidgets_utils import check_ipywidget_backend, make_unit_controller + from .utils_ipywidgets import check_ipywidget_backend, make_unit_controller check_ipywidget_backend() diff --git a/src/spikeinterface/widgets/unit_waveforms_density_map.py b/src/spikeinterface/widgets/unit_waveforms_density_map.py index 9216373d87..3320a232c6 100644 --- a/src/spikeinterface/widgets/unit_waveforms_density_map.py +++ b/src/spikeinterface/widgets/unit_waveforms_density_map.py @@ -159,7 +159,7 @@ def __init__( def plot_matplotlib(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt - from .matplotlib_utils import make_mpl_figure + from .utils_matplotlib import make_mpl_figure dp = to_attr(data_plot) # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) diff --git a/src/spikeinterface/widgets/ipywidgets_utils.py b/src/spikeinterface/widgets/utils_ipywidgets.py similarity index 100% rename from src/spikeinterface/widgets/ipywidgets_utils.py rename to src/spikeinterface/widgets/utils_ipywidgets.py diff --git a/src/spikeinterface/widgets/matplotlib_utils.py b/src/spikeinterface/widgets/utils_matplotlib.py similarity index 100% rename from src/spikeinterface/widgets/matplotlib_utils.py rename to src/spikeinterface/widgets/utils_matplotlib.py diff --git a/src/spikeinterface/widgets/sortingview_utils.py b/src/spikeinterface/widgets/utils_sortingview.py similarity index 100% rename from src/spikeinterface/widgets/sortingview_utils.py rename to src/spikeinterface/widgets/utils_sortingview.py diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index a753c78d4a..eab0345d53 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -1,81 +1,44 @@ -# from .base import define_widget_function_from_class from .base import backend_kwargs_desc -# basics -from .timeseries import TimeseriesWidget - -# waveform -from .unit_waveforms import UnitWaveformsWidget -from .unit_templates import UnitTemplatesWidget -from .unit_waveforms_density_map import UnitWaveformDensityMapWidget - -# isi/ccg/acg +from .all_amplitudes_distributions import AllAmplitudesDistributionsWidget +from .amplitudes import AmplitudesWidget from .autocorrelograms import AutoCorrelogramsWidget from .crosscorrelograms import CrossCorrelogramsWidget - -# peak activity - -# drift/motion - -# spikes-traces -from .spikes_on_traces import SpikesOnTracesWidget - -# PC related - -# units on probe -from .unit_locations import UnitLocationsWidget -from .spike_locations import SpikeLocationsWidget - -# unit presence - - -# comparison related - -# correlogram comparison - -# amplitudes -from .amplitudes import AmplitudesWidget -from .all_amplitudes_distributions import AllAmplitudesDistributionsWidget - -# metrics +from .motion import MotionWidget from .quality_metrics import QualityMetricsWidget +from .sorting_summary import SortingSummaryWidget +from .spike_locations import SpikeLocationsWidget +from .spikes_on_traces import SpikesOnTracesWidget from .template_metrics import TemplateMetricsWidget - - -# motion/drift -from .motion import MotionWidget - -# similarity from .template_similarity import TemplateSimilarityWidget - - +from .timeseries import TimeseriesWidget from .unit_depths import UnitDepthsWidget - -# summary +from 
.unit_locations import UnitLocationsWidget from .unit_summary import UnitSummaryWidget -from .sorting_summary import SortingSummaryWidget +from .unit_templates import UnitTemplatesWidget +from .unit_waveforms_density_map import UnitWaveformDensityMapWidget +from .unit_waveforms import UnitWaveformsWidget widget_list = [ - AmplitudesWidget, AllAmplitudesDistributionsWidget, + AmplitudesWidget, AutoCorrelogramsWidget, CrossCorrelogramsWidget, + MotionWidget, QualityMetricsWidget, + SortingSummaryWidget, SpikeLocationsWidget, SpikesOnTracesWidget, TemplateMetricsWidget, - MotionWidget, TemplateSimilarityWidget, TimeseriesWidget, + UnitDepthsWidget, UnitLocationsWidget, + UnitSummaryWidget, UnitTemplatesWidget, - UnitWaveformsWidget, UnitWaveformDensityMapWidget, - UnitDepthsWidget, - # summary - UnitSummaryWidget, - SortingSummaryWidget, + UnitWaveformsWidget, ] @@ -105,45 +68,21 @@ # make function for all widgets -# plot_amplitudes = define_widget_function_from_class(AmplitudesWidget, "plot_amplitudes") -# plot_all_amplitudes_distributions = define_widget_function_from_class( -# AllAmplitudesDistributionsWidget, "plot_all_amplitudes_distributions" -# ) -# plot_autocorrelograms = define_widget_function_from_class(AutoCorrelogramsWidget, "plot_autocorrelograms") -# plot_crosscorrelograms = define_widget_function_from_class(CrossCorrelogramsWidget, "plot_crosscorrelograms") -# plot_quality_metrics = define_widget_function_from_class(QualityMetricsWidget, "plot_quality_metrics") -# plot_spike_locations = define_widget_function_from_class(SpikeLocationsWidget, "plot_spike_locations") -# plot_spikes_on_traces = define_widget_function_from_class(SpikesOnTracesWidget, "plot_spikes_on_traces") -# plot_template_metrics = define_widget_function_from_class(TemplateMetricsWidget, "plot_template_metrics") -# plot_motion = define_widget_function_from_class(MotionWidget, "plot_motion") -# plot_template_similarity = define_widget_function_from_class(TemplateSimilarityWidget, "plot_template_similarity") -# plot_timeseries = define_widget_function_from_class(TimeseriesWidget, "plot_timeseries") -# plot_unit_locations = define_widget_function_from_class(UnitLocationsWidget, "plot_unit_locations") -# plot_unit_templates = define_widget_function_from_class(UnitTemplatesWidget, "plot_unit_templates") -# plot_unit_waveforms = define_widget_function_from_class(UnitWaveformsWidget, "plot_unit_waveforms") -# plot_unit_waveforms_density_map = define_widget_function_from_class( -# UnitWaveformDensityMapWidget, "plot_unit_waveforms_density_map" -# ) -# plot_unit_depths = define_widget_function_from_class(UnitDepthsWidget, "plot_unit_depths") -# plot_unit_summary = define_widget_function_from_class(UnitSummaryWidget, "plot_unit_summary") -# plot_sorting_summary = define_widget_function_from_class(SortingSummaryWidget, "plot_sorting_summary") - - -plot_amplitudes = AmplitudesWidget plot_all_amplitudes_distributions = AllAmplitudesDistributionsWidget -plot_unit_locations = UnitLocationsWidget +plot_amplitudes = AmplitudesWidget plot_autocorrelograms = AutoCorrelogramsWidget plot_crosscorrelograms = CrossCorrelogramsWidget +plot_motion = MotionWidget +plot_quality_metrics = QualityMetricsWidget +plot_sorting_summary = SortingSummaryWidget plot_spike_locations = SpikeLocationsWidget plot_spikes_on_traces = SpikesOnTracesWidget plot_template_metrics = TemplateMetricsWidget -plot_timeseries = TimeseriesWidget -plot_quality_metrics = QualityMetricsWidget -plot_motion = MotionWidget plot_template_similarity = 
TemplateSimilarityWidget
-plot_unit_templates = UnitTemplatesWidget
-plot_unit_waveforms = UnitWaveformsWidget
-plot_unit_waveforms_density_map = UnitWaveformDensityMapWidget
+plot_timeseries = TimeseriesWidget
 plot_unit_depths = UnitDepthsWidget
+plot_unit_locations = UnitLocationsWidget
 plot_unit_summary = UnitSummaryWidget
-plot_sorting_summary = SortingSummaryWidget
+plot_unit_templates = UnitTemplatesWidget
+plot_unit_waveforms_density_map = UnitWaveformDensityMapWidget
+plot_unit_waveforms = UnitWaveformsWidget

From fe763622ce5d866366248b583792953911950e79 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Thu, 20 Jul 2023 11:00:52 +0200
Subject: [PATCH 123/166] Update release notes

---
 doc/releases/0.98.2.rst | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/doc/releases/0.98.2.rst b/doc/releases/0.98.2.rst
index d60a3e53a3..2a326d1eb1 100644
--- a/doc/releases/0.98.2.rst
+++ b/doc/releases/0.98.2.rst
@@ -11,3 +11,7 @@ Minor release with some bug fixes.
 * Fix Mearec handling of new arguments before neo release 0.13 (#1848)
 * Fix full tests by updating hdbscan version (#1849)
 * Relax numpy upper bound and update tridesclous dependency (#1850)
+* Drop figurl-jupyter dependency (#1855)
+* Update Tridesclous to 1.6.8 (#1857)
+* Eliminate restore keys in CI and simplify installation of dev version dependencies (#1858)
+* Allow order_channel_by_depth to accept dimensions as list (#1861)

From 91064c4d30a185c24a33d9eeee2dbd681eab91f9 Mon Sep 17 00:00:00 2001
From: Samuel Garcia
Date: Thu, 20 Jul 2023 11:16:57 +0200
Subject: [PATCH 124/166] More clean

---
 .../widgets/all_amplitudes_distributions.py   |  5 +-
 src/spikeinterface/widgets/amplitudes.py      | 26 +------
 .../widgets/autocorrelograms.py               | 11 +--
 .../widgets/crosscorrelograms.py              | 10 +--
 src/spikeinterface/widgets/metrics.py         | 20 +----
 src/spikeinterface/widgets/motion.py          |  9 ---
 src/spikeinterface/widgets/quality_metrics.py |  1 -
 src/spikeinterface/widgets/sorting_summary.py | 45 +----------
 src/spikeinterface/widgets/spike_locations.py | 21 +-----
 .../widgets/spikes_on_traces.py               | 73 -------------------
 .../widgets/template_metrics.py               |  2 -
 .../widgets/template_similarity.py            |  9 +--
 src/spikeinterface/widgets/timeseries.py      | 26 +------
 src/spikeinterface/widgets/unit_depths.py     |  3 +-
 src/spikeinterface/widgets/unit_locations.py  | 14 +---
 src/spikeinterface/widgets/unit_summary.py    | 60 ---------------
 src/spikeinterface/widgets/unit_templates.py  |  8 +-
 src/spikeinterface/widgets/unit_waveforms.py  | 23 ------
 .../widgets/unit_waveforms_density_map.py     |  5 --
 .../widgets/utils_matplotlib.py               |  8 --
 .../widgets/utils_sortingview.py              |  8 --
 21 files changed, 17 insertions(+), 370 deletions(-)

diff --git a/src/spikeinterface/widgets/all_amplitudes_distributions.py b/src/spikeinterface/widgets/all_amplitudes_distributions.py
index 56aaa77804..280662fd7a 100644
--- a/src/spikeinterface/widgets/all_amplitudes_distributions.py
+++ b/src/spikeinterface/widgets/all_amplitudes_distributions.py
@@ -21,8 +21,6 @@ class AllAmplitudesDistributionsWidget(BaseWidget):
         Dict of colors with key: unit, value: color, default None
     """

-    possible_backends = {}
-
     def __init__(
         self, waveform_extractor: WaveformExtractor, unit_ids=None, unit_colors=None, backend=None, **backend_kwargs
     ):
@@ -56,8 +54,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
         from matplotlib.lines import Line2D

         dp = to_attr(data_plot)
-        # backend_kwargs = self.update_backend_kwargs(**backend_kwargs)
-        # self.make_mpl_figure(**backend_kwargs)
+        self.figure, self.axes, self.ax =
make_mpl_figure(**backend_kwargs) ax = self.ax diff --git a/src/spikeinterface/widgets/amplitudes.py b/src/spikeinterface/widgets/amplitudes.py index 2be71f7470..7ef6e0ff61 100644 --- a/src/spikeinterface/widgets/amplitudes.py +++ b/src/spikeinterface/widgets/amplitudes.py @@ -35,8 +35,6 @@ class AmplitudesWidget(BaseWidget): True includes legend in plot, default True """ - possible_backends = {} - def __init__( self, waveform_extractor: WaveformExtractor, @@ -116,13 +114,8 @@ def __init__( def plot_matplotlib(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt from .utils_matplotlib import make_mpl_figure - from probeinterface.plotting import plot_probe - - from matplotlib.patches import Ellipse - from matplotlib.lines import Line2D dp = to_attr(data_plot) - # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) if backend_kwargs["axes"] is not None: axes = backend_kwargs["axes"] @@ -139,7 +132,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): else: backend_kwargs["num_axes"] = None - # self.make_mpl_figure(**backend_kwargs) self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) scatter_ax = self.axes.flatten()[0] @@ -164,7 +156,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): self.figure.tight_layout() if dp.plot_legend: - # if self.legend is not None: if hasattr(self, "legend") and self.legend is not None: self.legend.remove() self.legend = self.figure.legend( @@ -191,7 +182,6 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): cm = 1 / 2.54 we = data_plot["waveform_extractor"] - # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) width_cm = backend_kwargs["width_cm"] height_cm = backend_kwargs["height_cm"] @@ -200,7 +190,6 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): with plt.ioff(): output = widgets.Output() with output: - # fig = plt.figure(figsize=((ratios[1] * width_cm) * cm, height_cm * cm)) self.figure = plt.figure(figsize=((ratios[1] * width_cm) * cm, height_cm * cm)) plt.show() @@ -220,15 +209,10 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.controller = {"plot_histograms": plot_histograms} self.controller.update(unit_controller) - # mpl_plotter = MplAmplitudesPlotter() - - # self.updater = PlotUpdater(data_plot, mpl_plotter, fig, self.controller) for w in self.controller.values(): - # w.observe(self.updater) w.observe(self._update_ipywidget) self.widget = widgets.AppLayout( - # center=fig.canvas, left_sidebar=unit_widget, pane_widths=ratios + [0], footer=footer center=self.figure.canvas, left_sidebar=unit_widget, pane_widths=ratios + [0], @@ -236,15 +220,12 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): ) # a first update - # self.updater(None) self._update_ipywidget(None) if backend_kwargs["display"]: - # self.check_backend() display(self.widget) def _update_ipywidget(self, change): - # self.fig.clear() self.figure.clear() unit_ids = self.controller["unit_ids"].value @@ -261,7 +242,6 @@ def _update_ipywidget(self, change): backend_kwargs["axes"] = None backend_kwargs["ax"] = None - # self.mpl_plotter.do_plot(data_plot, **backend_kwargs) self.plot_matplotlib(data_plot, **backend_kwargs) self.figure.canvas.draw() @@ -271,10 +251,8 @@ def plot_sortingview(self, data_plot, **backend_kwargs): import sortingview.views as vv from .utils_sortingview import generate_unit_table_view, make_serializable, handle_display_and_url - # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) dp = to_attr(data_plot) - # unit_ids = self.make_serializable(dp.unit_ids) 
unit_ids = make_serializable(dp.unit_ids) sa_items = [ @@ -286,10 +264,8 @@ def plot_sortingview(self, data_plot, **backend_kwargs): for u in unit_ids ] - # v_spike_amplitudes = vv.SpikeAmplitudes( self.view = vv.SpikeAmplitudes( start_time_sec=0, end_time_sec=dp.total_duration, plots=sa_items, hide_unit_selector=dp.hide_unit_selector ) - # self.handle_display_and_url(v_spike_amplitudes, **backend_kwargs) - self.url = handle_display_and_url(self, self.view, **self.backend_kwargs) + self.url = handle_display_and_url(self, self.view, **backend_kwargs) diff --git a/src/spikeinterface/widgets/autocorrelograms.py b/src/spikeinterface/widgets/autocorrelograms.py index ecb015bee2..e98abbed8f 100644 --- a/src/spikeinterface/widgets/autocorrelograms.py +++ b/src/spikeinterface/widgets/autocorrelograms.py @@ -4,7 +4,7 @@ class AutoCorrelogramsWidget(CrossCorrelogramsWidget): - # possible_backends = {} + # the doc is copied from CrossCorrelogramsWidget def __init__(self, *args, **kargs): CrossCorrelogramsWidget.__init__(self, *args, **kargs) @@ -14,12 +14,9 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): from .utils_matplotlib import make_mpl_figure dp = to_attr(data_plot) - # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) backend_kwargs["num_axes"] = len(dp.unit_ids) self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) - # self.make_mpl_figure(**backend_kwargs) - bins = dp.bins unit_ids = dp.unit_ids correlograms = dp.correlograms @@ -39,9 +36,7 @@ def plot_sortingview(self, data_plot, **backend_kwargs): import sortingview.views as vv from .utils_sortingview import make_serializable, handle_display_and_url - # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) dp = to_attr(data_plot) - # unit_ids = self.make_serializable(dp.unit_ids) unit_ids = make_serializable(dp.unit_ids) ac_items = [] @@ -58,9 +53,7 @@ def plot_sortingview(self, data_plot, **backend_kwargs): self.view = vv.Autocorrelograms(autocorrelograms=ac_items) - # self.handle_display_and_url(v_autocorrelograms, **backend_kwargs) - # return v_autocorrelograms - self.url = handle_display_and_url(self, self.view, **self.backend_kwargs) + self.url = handle_display_and_url(self, self.view, **backend_kwargs) AutoCorrelogramsWidget.__doc__ = CrossCorrelogramsWidget.__doc__ diff --git a/src/spikeinterface/widgets/crosscorrelograms.py b/src/spikeinterface/widgets/crosscorrelograms.py index 5635466a2d..3ec3fa11b6 100644 --- a/src/spikeinterface/widgets/crosscorrelograms.py +++ b/src/spikeinterface/widgets/crosscorrelograms.py @@ -27,8 +27,6 @@ class CrossCorrelogramsWidget(BaseWidget): If given, a dictionary with unit ids as keys and colors as values, default None """ - # possible_backends = {} - def __init__( self, waveform_or_sorting_extractor: Union[WaveformExtractor, BaseSorting], @@ -71,11 +69,9 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): from .utils_matplotlib import make_mpl_figure dp = to_attr(data_plot) - # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) backend_kwargs["ncols"] = len(dp.unit_ids) backend_kwargs["num_axes"] = int(len(dp.unit_ids) ** 2) - # self.make_mpl_figure(**backend_kwargs) self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) assert self.axes.ndim == 2 @@ -106,10 +102,8 @@ def plot_sortingview(self, data_plot, **backend_kwargs): import sortingview.views as vv from .utils_sortingview import generate_unit_table_view, make_serializable, handle_display_and_url - # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) dp = 
to_attr(data_plot) - # unit_ids = self.make_serializable(dp.unit_ids) unit_ids = make_serializable(dp.unit_ids) cc_items = [] @@ -126,6 +120,4 @@ def plot_sortingview(self, data_plot, **backend_kwargs): self.view = vv.CrossCorrelograms(cross_correlograms=cc_items, hide_unit_selector=dp.hide_unit_selector) - # self.handle_display_and_url(v_cross_correlograms, **backend_kwargs) - # return v_cross_correlograms - self.url = handle_display_and_url(self, self.view, **self.backend_kwargs) + self.url = handle_display_and_url(self, self.view, **backend_kwargs) diff --git a/src/spikeinterface/widgets/metrics.py b/src/spikeinterface/widgets/metrics.py index 3d5e247b93..9dc51f522e 100644 --- a/src/spikeinterface/widgets/metrics.py +++ b/src/spikeinterface/widgets/metrics.py @@ -30,8 +30,6 @@ class MetricsBaseWidget(BaseWidget): If True, metrics data are included in unit table, by default True """ - # possible_backends = {} - def __init__( self, metrics, @@ -90,13 +88,11 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): if "figsize" not in backend_kwargs: backend_kwargs["figsize"] = (2 * num_metrics, 2 * num_metrics) - # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) backend_kwargs["num_axes"] = num_metrics**2 backend_kwargs["ncols"] = num_metrics all_unit_ids = metrics.index.values - # self.make_mpl_figure(**backend_kwargs) self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) assert self.axes.ndim == 2 @@ -160,11 +156,6 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.controller = unit_controller - # mpl_plotter = MplMetricsPlotter() - - # self.updater = PlotUpdater(data_plot, mpl_plotter, fig, self.controller) - # for w in self.controller.values(): - # w.observe(self.updater) for w in self.controller.values(): w.observe(self._update_ipywidget) @@ -175,11 +166,9 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): ) # a first update - # self.updater(None) self._update_ipywidget(None) if backend_kwargs["display"]: - # self.check_backend() display(self.widget) def _update_ipywidget(self, change): @@ -199,16 +188,13 @@ def _update_ipywidget(self, change): sizes.append(size) # here we do a trick: we just update colors - # if hasattr(self.mpl_plotter, "patches"): if hasattr(self, "patches"): - # for p in self.mpl_plotter.patches: for p in self.patches: p.set_color(colors) p.set_sizes(sizes) else: backend_kwargs = {} backend_kwargs["figure"] = self.figure - # self.mpl_plotter.do_plot(self.data_plot, **backend_kwargs) self.plot_matplotlib(self.data_plot, **backend_kwargs) if len(unit_ids) > 0: @@ -231,7 +217,6 @@ def plot_sortingview(self, data_plot, **backend_kwargs): from .utils_sortingview import generate_unit_table_view, make_serializable, handle_display_and_url dp = to_attr(data_plot) - # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) metrics = dp.metrics metric_names = list(metrics.columns) @@ -240,7 +225,6 @@ def plot_sortingview(self, data_plot, **backend_kwargs): unit_ids = metrics.index.values else: unit_ids = dp.unit_ids - # unit_ids = self.make_serializable(unit_ids) unit_ids = make_serializable(unit_ids) metrics_sv = [] @@ -279,6 +263,4 @@ def plot_sortingview(self, data_plot, **backend_kwargs): else: self.view = v_metrics - # self.handle_display_and_url(view, **backend_kwargs) - # return view - self.url = handle_display_and_url(self, self.view, **self.backend_kwargs) + self.url = handle_display_and_url(self, self.view, **backend_kwargs) diff --git a/src/spikeinterface/widgets/motion.py 
b/src/spikeinterface/widgets/motion.py index 6420fe8848..cb11bcce0c 100644 --- a/src/spikeinterface/widgets/motion.py +++ b/src/spikeinterface/widgets/motion.py @@ -1,11 +1,6 @@ import numpy as np -from warnings import warn from .base import BaseWidget, to_attr -from .utils import get_unit_colors - - -from ..core.template_tools import get_template_extremum_amplitude class MotionWidget(BaseWidget): @@ -36,8 +31,6 @@ class MotionWidget(BaseWidget): The alpha of the scatter points, default 0.5 """ - # possible_backends = {} - def __init__( self, motion_info, @@ -77,12 +70,10 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): from spikeinterface.sortingcomponents.motion_interpolation import correct_motion_on_peaks dp = to_attr(data_plot) - # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) assert backend_kwargs["axes"] is None assert backend_kwargs["ax"] is None - # self.make_mpl_figure(**backend_kwargs) self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) fig = self.figure fig.clear() diff --git a/src/spikeinterface/widgets/quality_metrics.py b/src/spikeinterface/widgets/quality_metrics.py index 46bcd6c07b..459a32e6f2 100644 --- a/src/spikeinterface/widgets/quality_metrics.py +++ b/src/spikeinterface/widgets/quality_metrics.py @@ -1,6 +1,5 @@ from .metrics import MetricsBaseWidget from ..core.waveform_extractor import WaveformExtractor -from ..qualitymetrics import compute_quality_metrics class QualityMetricsWidget(MetricsBaseWidget): diff --git a/src/spikeinterface/widgets/sorting_summary.py b/src/spikeinterface/widgets/sorting_summary.py index 9b4279d94e..9291de2956 100644 --- a/src/spikeinterface/widgets/sorting_summary.py +++ b/src/spikeinterface/widgets/sorting_summary.py @@ -9,7 +9,7 @@ from .unit_templates import UnitTemplatesWidget -from ..core import WaveformExtractor, ChannelSparsity +from ..core import WaveformExtractor class SortingSummaryWidget(BaseWidget): @@ -55,26 +55,10 @@ def __init__( if unit_ids is None: unit_ids = sorting.get_unit_ids() - # use other widgets to generate data (except for similarity) - # template_plot_data = UnitTemplatesWidget( - # we, unit_ids=unit_ids, sparsity=sparsity, hide_unit_selector=True - # ).plot_data - # ccg_plot_data = CrossCorrelogramsWidget(we, unit_ids=unit_ids, hide_unit_selector=True).plot_data - # amps_plot_data = AmplitudesWidget( - # we, unit_ids=unit_ids, max_spikes_per_unit=max_amplitudes_per_unit, hide_unit_selector=True - # ).plot_data - # locs_plot_data = UnitLocationsWidget(we, unit_ids=unit_ids, hide_unit_selector=True).plot_data - # sim_plot_data = TemplateSimilarityWidget(we, unit_ids=unit_ids).plot_data - plot_data = dict( waveform_extractor=waveform_extractor, unit_ids=unit_ids, sparsity=sparsity, - # templates=template_plot_data, - # correlograms=ccg_plot_data, - # amplitudes=amps_plot_data, - # similarity=sim_plot_data, - # unit_locations=locs_plot_data, unit_table_properties=unit_table_properties, curation=curation, label_choices=label_choices, @@ -92,28 +76,8 @@ def plot_sortingview(self, data_plot, **backend_kwargs): unit_ids = dp.unit_ids sparsity = dp.sparsity - # unit_ids = self.make_serializable(dp.unit_ids) unit_ids = make_serializable(dp.unit_ids) - # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - - # amplitudes_plotter = AmplitudesPlotter() - # v_spike_amplitudes = amplitudes_plotter.do_plot( - # dp.amplitudes, generate_url=False, display=False, backend="sortingview" - # ) - # template_plotter = UnitTemplatesPlotter() - # v_average_waveforms = 
template_plotter.do_plot( - # dp.templates, generate_url=False, display=False, backend="sortingview" - # ) - # xcorrelograms_plotter = CrossCorrelogramsPlotter() - # v_cross_correlograms = xcorrelograms_plotter.do_plot( - # dp.correlograms, generate_url=False, display=False, backend="sortingview" - # ) - # unitlocation_plotter = UnitLocationsPlotter() - # v_unit_locations = unitlocation_plotter.do_plot( - # dp.unit_locations, generate_url=False, display=False, backend="sortingview" - # ) - v_spike_amplitudes = AmplitudesWidget( we, unit_ids=unit_ids, @@ -144,7 +108,6 @@ def plot_sortingview(self, data_plot, **backend_kwargs): we, unit_ids=unit_ids, immediate_plot=False, generate_url=False, display=False, backend="sortingview" ) similarity = w.data_plot["similarity"] - print(similarity.shape) # similarity similarity_scores = [] @@ -183,10 +146,6 @@ def plot_sortingview(self, data_plot, **backend_kwargs): ) # assemble layout - # v_summary = vv.Splitter(direction="horizontal", item1=vv.LayoutItem(v1), item2=vv.LayoutItem(v2)) self.view = vv.Splitter(direction="horizontal", item1=vv.LayoutItem(v1), item2=vv.LayoutItem(v2)) - # self.handle_display_and_url(v_summary, **backend_kwargs) - # return v_summary - - self.url = handle_display_and_url(self, self.view, **self.backend_kwargs) + self.url = handle_display_and_url(self, self.view, **backend_kwargs) diff --git a/src/spikeinterface/widgets/spike_locations.py b/src/spikeinterface/widgets/spike_locations.py index 62feff9372..9771b2c0e9 100644 --- a/src/spikeinterface/widgets/spike_locations.py +++ b/src/spikeinterface/widgets/spike_locations.py @@ -1,5 +1,4 @@ import numpy as np -from typing import Union from .base import BaseWidget, to_attr from .utils import get_unit_colors @@ -114,9 +113,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): from probeinterface.plotting import plot_probe dp = to_attr(data_plot) - # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - # self.make_mpl_figure(**backend_kwargs) self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) spike_locations = dp.spike_locations @@ -168,7 +165,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): for unit in dp.unit_ids ] if dp.plot_legend: - # if self.legend is not None: if hasattr(self, "legend") and self.legend is not None: self.legend.remove() self.legend = self.figure.legend( @@ -203,7 +199,6 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): cm = 1 / 2.54 - # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) width_cm = backend_kwargs["width_cm"] height_cm = backend_kwargs["height_cm"] @@ -226,12 +221,6 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.controller = unit_controller - # mpl_plotter = MplSpikeLocationsPlotter() - - # self.updater = PlotUpdater(data_plot, mpl_plotter, ax, self.controller) - # for w in self.controller.values(): - # w.observe(self.updater) - for w in self.controller.values(): w.observe(self._update_ipywidget) @@ -242,11 +231,9 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): ) # a first update - # self.updater(None) self._update_ipywidget(None) if backend_kwargs["display"]: - # self.check_backend() display(self.widget) def _update_ipywidget(self, change): @@ -274,12 +261,10 @@ def plot_sortingview(self, data_plot, **backend_kwargs): import sortingview.views as vv from .utils_sortingview import generate_unit_table_view, make_serializable, handle_display_and_url - # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) dp = to_attr(data_plot) 
spike_locations = dp.spike_locations # ensure serializable for sortingview - # unit_ids, channel_ids = self.make_serializable(dp.unit_ids, dp.channel_ids) unit_ids, channel_ids = make_serializable(dp.unit_ids, dp.channel_ids) locations = {str(ch): dp.channel_locations[i_ch].astype("float32") for i_ch, ch in enumerate(channel_ids)} @@ -321,11 +306,7 @@ def plot_sortingview(self, data_plot, **backend_kwargs): else: self.view = v_spike_locations - # self.set_view(view) - - # self.handle_display_and_url(view, **backend_kwargs) - # return view - self.url = handle_display_and_url(self, self.view, **self.backend_kwargs) + self.url = handle_display_and_url(self, self.view, **backend_kwargs) def estimate_axis_lims(spike_locations, quantile=0.02): diff --git a/src/spikeinterface/widgets/spikes_on_traces.py b/src/spikeinterface/widgets/spikes_on_traces.py index 74fc7f7501..ab4e629a2e 100644 --- a/src/spikeinterface/widgets/spikes_on_traces.py +++ b/src/spikeinterface/widgets/spikes_on_traces.py @@ -60,8 +60,6 @@ class SpikesOnTracesWidget(BaseWidget): For 'map' mode and sortingview backend, seconds to render in each row, default 0.2 """ - # possible_backends = {} - def __init__( self, waveform_extractor: WaveformExtractor, @@ -86,29 +84,8 @@ def __init__( **backend_kwargs, ): we = waveform_extractor - # recording: BaseRecording = we.recording sorting: BaseSorting = we.sorting - # ts_widget = TimeseriesWidget( - # recording, - # segment_index, - # channel_ids, - # order_channel_by_depth, - # time_range, - # mode, - # return_scaled, - # cmap, - # show_channel_ids, - # color_groups, - # color, - # clim, - # tile_size, - # seconds_per_row, - # with_colorbar, - # backend, - # **backend_kwargs, - # ) - if unit_ids is None: unit_ids = sorting.get_unit_ids() unit_ids = unit_ids @@ -150,7 +127,6 @@ def __init__( ) plot_data = dict( - # timeseries=ts_widget.plot_data, waveform_extractor=waveform_extractor, options=options, unit_ids=unit_ids, @@ -173,14 +149,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): recording = we.recording sorting = we.sorting - # first plot time series - # tsplotter = TimeseriesPlotter() - # data_plot["timeseries"]["add_legend"] = False - # tsplotter.do_plot(dp.timeseries, **backend_kwargs) - # self.ax = tsplotter.ax - # self.axes = tsplotter.axes - # self.figure = tsplotter.figure - # first plot time series ts_widget = TimeseriesWidget(recording, **dp.options, backend="matplotlib", **backend_kwargs) self.ax = ts_widget.ax @@ -189,20 +157,11 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax = self.ax - # we = dp.waveform_extractor - # sorting = dp.waveform_extractor.sorting - # frame_range = dp.timeseries["frame_range"] - # segment_index = dp.timeseries["segment_index"] - # min_y = np.min(dp.timeseries["channel_locations"][:, 1]) - # max_y = np.max(dp.timeseries["channel_locations"][:, 1]) - frame_range = ts_widget.data_plot["frame_range"] segment_index = ts_widget.data_plot["segment_index"] min_y = np.min(ts_widget.data_plot["channel_locations"][:, 1]) max_y = np.max(ts_widget.data_plot["channel_locations"][:, 1]) - # n = len(dp.timeseries["channel_ids"]) - # order = dp.timeseries["order"] n = len(ts_widget.data_plot["channel_ids"]) order = ts_widget.data_plot["order"] @@ -224,7 +183,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): spike_frames_to_plot = spike_frames[spike_start:spike_end] - # if dp.timeseries["mode"] == "map": if dp.options["mode"] == "map": spike_times_to_plot = sorting.get_unit_spike_train( unit, segment_index=segment_index, 
return_times=True @@ -253,16 +211,12 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): # construct waveforms label_set = False if len(spike_frames_to_plot) > 0: - # vspacing = dp.timeseries["vspacing"] - # traces = dp.timeseries["list_traces"][0] vspacing = ts_widget.data_plot["vspacing"] traces = ts_widget.data_plot["list_traces"][0] waveform_idxs = spike_frames_to_plot[:, None] + np.arange(-we.nbefore, we.nafter) - frame_range[0] - # waveform_idxs = np.clip(waveform_idxs, 0, len(dp.timeseries["times"]) - 1) waveform_idxs = np.clip(waveform_idxs, 0, len(ts_widget.data_plot["times"]) - 1) - # times = dp.timeseries["times"][waveform_idxs] times = ts_widget.data_plot["times"][waveform_idxs] # discontinuity @@ -271,7 +225,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): waveforms = traces[waveform_idxs] # [:, :, order] waveforms_r = waveforms.reshape((waveforms.shape[0] * waveforms.shape[1], waveforms.shape[2])) - # for i, chan_id in enumerate(dp.timeseries["channel_ids"]): for i, chan_id in enumerate(ts_widget.data_plot["channel_ids"]): offset = vspacing * i if chan_id in chan_ids: @@ -296,7 +249,6 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): we = dp.waveform_extractor ratios = [0.2, 0.8] - # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) backend_kwargs_ts = backend_kwargs.copy() backend_kwargs_ts["width_cm"] = ratios[1] * backend_kwargs_ts["width_cm"] @@ -305,46 +257,28 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): width_cm = backend_kwargs["width_cm"] # plot timeseries - # tsplotter = TimeseriesPlotter() - # data_plot["timeseries"]["add_legend"] = False - # tsplotter.do_plot(data_plot["timeseries"], **backend_kwargs_ts) - - # ts_w = tsplotter.widget - # ts_updater = tsplotter.updater - ts_widget = TimeseriesWidget(we.recording, **dp.options, backend="ipywidgets", **backend_kwargs_ts) self.ax = ts_widget.ax self.axes = ts_widget.axes self.figure = ts_widget.figure - # we = data_plot["waveform_extractor"] - unit_widget, unit_controller = make_unit_controller( data_plot["unit_ids"], we.unit_ids, ratios[0] * width_cm, height_cm ) self.controller = dict() - # self.controller = ts_updater.controller self.controller.update(ts_widget.controller) self.controller.update(unit_controller) - # mpl_plotter = MplSpikesOnTracesPlotter() - - # self.updater = PlotUpdater(data_plot, mpl_plotter, ts_updater, self.controller) - # for w in self.controller.values(): - # w.observe(self.updater) - for w in self.controller.values(): w.observe(self._update_ipywidget) self.widget = widgets.AppLayout(center=ts_widget.widget, left_sidebar=unit_widget, pane_widths=ratios + [0]) # a first update - # self.updater(None) self._update_ipywidget(None) if backend_kwargs["display"]: - # self.check_backend() display(self.widget) def _update_ipywidget(self, change): @@ -352,19 +286,12 @@ def _update_ipywidget(self, change): unit_ids = self.controller["unit_ids"].value - # update ts - # self.ts_updater.__call__(change) - - # update data plot - # data_plot = self.data_plot.copy() data_plot = self.next_data_plot - # data_plot["timeseries"] = self.ts_updater.next_data_plot data_plot["unit_ids"] = unit_ids backend_kwargs = {} backend_kwargs["ax"] = self.ax - # self.mpl_plotter.do_plot(data_plot, **backend_kwargs) self.plot_matplotlib(data_plot, **backend_kwargs) self.figure.canvas.draw() diff --git a/src/spikeinterface/widgets/template_metrics.py b/src/spikeinterface/widgets/template_metrics.py index 7361757666..748babb57d 100644 --- 
a/src/spikeinterface/widgets/template_metrics.py +++ b/src/spikeinterface/widgets/template_metrics.py @@ -22,8 +22,6 @@ class TemplateMetricsWidget(MetricsBaseWidget): For sortingview backend, if True the unit selector is not displayed, default False """ - # possible_backends = {} - def __init__( self, waveform_extractor: WaveformExtractor, diff --git a/src/spikeinterface/widgets/template_similarity.py b/src/spikeinterface/widgets/template_similarity.py index f43a47db62..69aad70b1f 100644 --- a/src/spikeinterface/widgets/template_similarity.py +++ b/src/spikeinterface/widgets/template_similarity.py @@ -1,5 +1,4 @@ import numpy as np -from typing import Union from .base import BaseWidget, to_attr from ..core.waveform_extractor import WaveformExtractor @@ -68,9 +67,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): from .utils_matplotlib import make_mpl_figure dp = to_attr(data_plot) - # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - # self.make_mpl_figure(**backend_kwargs) self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) im = self.ax.matshow(dp.similarity, cmap=dp.cmap) @@ -91,11 +88,9 @@ def plot_sortingview(self, data_plot, **backend_kwargs): import sortingview.views as vv from .utils_sortingview import generate_unit_table_view, make_serializable, handle_display_and_url - # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) dp = to_attr(data_plot) # ensure serializable for sortingview - # unit_ids = self.make_serializable(dp.unit_ids) unit_ids = make_serializable(dp.unit_ids) # similarity @@ -108,6 +103,4 @@ def plot_sortingview(self, data_plot, **backend_kwargs): self.view = vv.UnitSimilarityMatrix(unit_ids=list(unit_ids), similarity_scores=ss_items) - # self.handle_display_and_url(view, **backend_kwargs) - # return view - self.url = handle_display_and_url(self, self.view, **self.backend_kwargs) + self.url = handle_display_and_url(self, self.view, **backend_kwargs) diff --git a/src/spikeinterface/widgets/timeseries.py b/src/spikeinterface/widgets/timeseries.py index 7165dec12a..9439694639 100644 --- a/src/spikeinterface/widgets/timeseries.py +++ b/src/spikeinterface/widgets/timeseries.py @@ -58,8 +58,6 @@ class TimeseriesWidget(BaseWidget): The output widget """ - # possible_backends = {} - def __init__( self, recording, @@ -221,9 +219,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): from .utils_matplotlib import make_mpl_figure dp = to_attr(data_plot) - # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - # self.make_mpl_figure(**backend_kwargs) self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) ax = self.ax @@ -302,7 +298,6 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): cm = 1 / 2.54 - # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) width_cm = backend_kwargs["width_cm"] height_cm = backend_kwargs["height_cm"] ratios = [0.1, 0.8, 0.2] @@ -335,15 +330,6 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.controller.update(ch_controller) self.controller.update(scale_controller) - # mpl_plotter = MplTimeseriesPlotter() - - # self.updater = PlotUpdater(data_plot, mpl_plotter, ax, self.controller) - # for w in self.controller.values(): - # if isinstance(w, widgets.Button): - # w.on_click(self.updater) - # else: - # w.observe(self.updater) - self.recordings = data_plot["recordings"] self.return_scaled = data_plot["return_scaled"] self.list_traces = None @@ -371,7 +357,6 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): ) # a first update - # 
self.updater(None) self._update_ipywidget(None) if backend_kwargs["display"]: @@ -497,7 +482,7 @@ def _update_ipywidget(self, change): backend_kwargs = {} backend_kwargs["ax"] = self.ax - # self.mpl_plotter.do_plot(data_plot, **backend_kwargs) + self.plot_matplotlib(data_plot, **backend_kwargs) fig = self.ax.figure @@ -506,7 +491,7 @@ def _update_ipywidget(self, change): def plot_sortingview(self, data_plot, **backend_kwargs): import sortingview.views as vv - from .utils_sortingview import generate_unit_table_view, make_serializable, handle_display_and_url + from .utils_sortingview import handle_display_and_url try: import pyvips @@ -536,17 +521,12 @@ def plot_sortingview(self, data_plot, **backend_kwargs): tiled_layers.append(vv.TiledImageLayer(layer_key, img)) - # view_ts = vv.TiledImage(tile_size=dp.tile_size, layers=tiled_layers) self.view = vv.TiledImage(tile_size=dp.tile_size, layers=tiled_layers) - # self.set_view(view_ts) - # timeseries currently doesn't display on the jupyter backend backend_kwargs["display"] = False - # self.handle_display_and_url(view_ts, **backend_kwargs) - # return view_ts - self.url = handle_display_and_url(self, self.view, **self.backend_kwargs) + self.url = handle_display_and_url(self, self.view, **backend_kwargs) def _get_trace_list(recordings, channel_ids, time_range, segment_index, order=None, return_scaled=False): diff --git a/src/spikeinterface/widgets/unit_depths.py b/src/spikeinterface/widgets/unit_depths.py index 9bcafb53e4..e48f274962 100644 --- a/src/spikeinterface/widgets/unit_depths.py +++ b/src/spikeinterface/widgets/unit_depths.py @@ -62,8 +62,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): from .utils_matplotlib import make_mpl_figure dp = to_attr(data_plot) - # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - # self.make_mpl_figure(**backend_kwargs) + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) ax = self.ax diff --git a/src/spikeinterface/widgets/unit_locations.py b/src/spikeinterface/widgets/unit_locations.py index b923374a07..f8ea042f84 100644 --- a/src/spikeinterface/widgets/unit_locations.py +++ b/src/spikeinterface/widgets/unit_locations.py @@ -121,7 +121,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): self.ax.set_title("") - # color = np.array([dp.unit_colors[unit_id] for unit_id in dp.unit_ids]) width = height = 10 ellipse_kwargs = dict(width=width, height=height, lw=2) @@ -178,8 +177,6 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): cm = 1 / 2.54 - # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - width_cm = backend_kwargs["width_cm"] height_cm = backend_kwargs["height_cm"] @@ -198,12 +195,6 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.controller = unit_controller - # mpl_plotter = MplUnitLocationsPlotter() - - # self.updater = PlotUpdater(data_plot, mpl_plotter, ax, self.controller) - # for w in self.controller.values(): - # w.observe(self.updater) - for w in self.controller.values(): w.observe(self._update_ipywidget) @@ -234,7 +225,6 @@ def _update_ipywidget(self, change): backend_kwargs = {} backend_kwargs["ax"] = self.ax - # self.mpl_plotter.do_plot(data_plot, **backend_kwargs) self.plot_matplotlib(data_plot, **backend_kwargs) fig = self.ax.get_figure() fig.canvas.draw() @@ -244,7 +234,6 @@ def plot_sortingview(self, data_plot, **backend_kwargs): import sortingview.views as vv from .utils_sortingview import generate_unit_table_view, make_serializable, handle_display_and_url - # backend_kwargs = 
self.update_backend_kwargs(**backend_kwargs) dp = to_attr(data_plot) # ensure serializable for sortingview @@ -272,5 +261,4 @@ def plot_sortingview(self, data_plot, **backend_kwargs): else: self.view = v_unit_locations - # self.handle_display_and_url(view, **backend_kwargs) - self.url = handle_display_and_url(self, self.view, **self.backend_kwargs) + self.url = handle_display_and_url(self, self.view, **backend_kwargs) diff --git a/src/spikeinterface/widgets/unit_summary.py b/src/spikeinterface/widgets/unit_summary.py index 82e3e79fb9..964b5813e6 100644 --- a/src/spikeinterface/widgets/unit_summary.py +++ b/src/spikeinterface/widgets/unit_summary.py @@ -1,5 +1,4 @@ import numpy as np -from typing import Union from .base import BaseWidget, to_attr from .utils import get_unit_colors @@ -48,58 +47,11 @@ def __init__( if unit_colors is None: unit_colors = get_unit_colors(we.sorting) - # if we.is_extension("unit_locations"): - # plot_data_unit_locations = UnitLocationsWidget( - # we, unit_ids=[unit_id], unit_colors=unit_colors, plot_legend=False - # ).plot_data - # unit_locations = waveform_extractor.load_extension("unit_locations").get_data(outputs="by_unit") - # unit_location = unit_locations[unit_id] - # else: - # plot_data_unit_locations = None - # unit_location = None - - # plot_data_waveforms = UnitWaveformsWidget( - # we, - # unit_ids=[unit_id], - # unit_colors=unit_colors, - # plot_templates=True, - # same_axis=True, - # plot_legend=False, - # sparsity=sparsity, - # ).plot_data - - # plot_data_waveform_density = UnitWaveformDensityMapWidget( - # we, unit_ids=[unit_id], unit_colors=unit_colors, use_max_channel=True, plot_templates=True, same_axis=False - # ).plot_data - - # if we.is_extension("correlograms"): - # plot_data_acc = AutoCorrelogramsWidget( - # we, - # unit_ids=[unit_id], - # unit_colors=unit_colors, - # ).plot_data - # else: - # plot_data_acc = None - - # use other widget to plot data - # if we.is_extension("spike_amplitudes"): - # plot_data_amplitudes = AmplitudesWidget( - # we, unit_ids=[unit_id], unit_colors=unit_colors, plot_legend=False, plot_histograms=True - # ).plot_data - # else: - # plot_data_amplitudes = None - plot_data = dict( we=we, unit_id=unit_id, unit_colors=unit_colors, sparsity=sparsity, - # unit_location=unit_location, - # plot_data_unit_locations=plot_data_unit_locations, - # plot_data_waveforms=plot_data_waveforms, - # plot_data_waveform_density=plot_data_waveform_density, - # plot_data_acc=plot_data_acc, - # plot_data_amplitudes=plot_data_amplitudes, ) BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) @@ -118,27 +70,22 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): # force the figure without axes if "figsize" not in backend_kwargs: backend_kwargs["figsize"] = (18, 7) - # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) backend_kwargs["num_axes"] = 0 backend_kwargs["ax"] = None backend_kwargs["axes"] = None - # self.make_mpl_figure(**backend_kwargs) self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) # and use custom grid spec fig = self.figure nrows = 2 ncols = 3 - # if dp.plot_data_acc is not None or dp.plot_data_amplitudes is not None: if we.is_extension("correlograms") or we.is_extension("spike_amplitudes"): ncols += 1 - # if dp.plot_data_amplitudes is not None : if we.is_extension("spike_amplitudes"): nrows += 1 gs = fig.add_gridspec(nrows, ncols) - # if dp.plot_data_unit_locations is not None: if we.is_extension("unit_locations"): ax1 = fig.add_subplot(gs[:2, 0]) # 
UnitLocationsPlotter().do_plot(dp.plot_data_unit_locations, ax=ax1) @@ -148,7 +95,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): unit_locations = we.load_extension("unit_locations").get_data(outputs="by_unit") unit_location = unit_locations[unit_id] - # x, y = dp.unit_location[0], dp.unit_location[1] x, y = unit_location[0], unit_location[1] ax1.set_xlim(x - 80, x + 80) ax1.set_ylim(y - 250, y + 250) @@ -157,7 +103,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax1.set_ylabel(None) ax2 = fig.add_subplot(gs[:2, 1]) - # UnitWaveformPlotter().do_plot(dp.plot_data_waveforms, ax=ax2) w = UnitWaveformsWidget( we, unit_ids=[unit_id], @@ -173,7 +118,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax2.set_title(None) ax3 = fig.add_subplot(gs[:2, 2]) - # UnitWaveformDensityMapPlotter().do_plot(dp.plot_data_waveform_density, ax=ax3) UnitWaveformDensityMapWidget( we, unit_ids=[unit_id], @@ -185,10 +129,8 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ) ax3.set_ylabel(None) - # if dp.plot_data_acc is not None: if we.is_extension("correlograms"): ax4 = fig.add_subplot(gs[:2, 3]) - # AutoCorrelogramsPlotter().do_plot(dp.plot_data_acc, ax=ax4) AutoCorrelogramsWidget( we, unit_ids=[unit_id], @@ -200,12 +142,10 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax4.set_title(None) ax4.set_yticks([]) - # if dp.plot_data_amplitudes is not None: if we.is_extension("spike_amplitudes"): ax5 = fig.add_subplot(gs[2, :3]) ax6 = fig.add_subplot(gs[2, 3]) axes = np.array([ax5, ax6]) - # AmplitudesPlotter().do_plot(dp.plot_data_amplitudes, axes=axes) AmplitudesWidget( we, unit_ids=[unit_id], diff --git a/src/spikeinterface/widgets/unit_templates.py b/src/spikeinterface/widgets/unit_templates.py index 7e9a1c21a8..cf58e91aa0 100644 --- a/src/spikeinterface/widgets/unit_templates.py +++ b/src/spikeinterface/widgets/unit_templates.py @@ -3,7 +3,7 @@ class UnitTemplatesWidget(UnitWaveformsWidget): - # possible_backends = {} + # doc is copied from UnitWaveformsWidget def __init__(self, *args, **kargs): kargs["plot_waveforms"] = False @@ -14,13 +14,11 @@ def plot_sortingview(self, data_plot, **backend_kwargs): from .utils_sortingview import generate_unit_table_view, make_serializable, handle_display_and_url dp = to_attr(data_plot) - # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) # ensure serializable for sortingview unit_id_to_channel_ids = dp.sparsity.unit_id_to_channel_ids unit_id_to_channel_indices = dp.sparsity.unit_id_to_channel_indices - # unit_ids, channel_ids = self.make_serializable(dp.unit_ids, dp.channel_ids) unit_ids, channel_ids = make_serializable(dp.unit_ids, dp.channel_ids) templates_dict = {} @@ -52,9 +50,7 @@ def plot_sortingview(self, data_plot, **backend_kwargs): else: self.view = v_average_waveforms - # self.handle_display_and_url(view, **backend_kwargs) - # return view - self.url = handle_display_and_url(self, self.view, **self.backend_kwargs) + self.url = handle_display_and_url(self, self.view, **backend_kwargs) UnitTemplatesWidget.__doc__ = UnitWaveformsWidget.__doc__ diff --git a/src/spikeinterface/widgets/unit_waveforms.py b/src/spikeinterface/widgets/unit_waveforms.py index f82d276d92..e64765b44b 100644 --- a/src/spikeinterface/widgets/unit_waveforms.py +++ b/src/spikeinterface/widgets/unit_waveforms.py @@ -59,8 +59,6 @@ class UnitWaveformsWidget(BaseWidget): Display legend, default True """ - # possible_backends = {} - def __init__( self, waveform_extractor: WaveformExtractor, @@ -168,15 +166,9 @@ def __init__( def 
plot_matplotlib(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt from .utils_matplotlib import make_mpl_figure - from probeinterface.plotting import plot_probe - - from matplotlib.patches import Ellipse - from matplotlib.lines import Line2D dp = to_attr(data_plot) - # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) - if backend_kwargs.get("axes", None) is not None: assert len(backend_kwargs["axes"]) >= len(dp.unit_ids), "Provide as many 'axes' as neurons" elif backend_kwargs.get("ax", None) is not None: @@ -189,7 +181,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): backend_kwargs["num_axes"] = len(dp.unit_ids) backend_kwargs["ncols"] = min(dp.ncols, len(dp.unit_ids)) - # self.make_mpl_figure(**backend_kwargs) self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) for i, unit_id in enumerate(dp.unit_ids): @@ -249,7 +240,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax.scatter(dp.channel_locations[:, 0], dp.channel_locations[:, 1], color="k") if dp.same_axis and dp.plot_legend: - # if self.legend is not None: if hasattr(self, "legend") and self.legend is not None: self.legend.remove() self.legend = self.figure.legend( @@ -269,7 +259,6 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): cm = 1 / 2.54 self.we = we = data_plot["waveform_extractor"] - # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) width_cm = backend_kwargs["width_cm"] height_cm = backend_kwargs["height_cm"] @@ -317,12 +306,6 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): } self.controller.update(unit_controller) - # mpl_plotter = MplUnitWaveformPlotter() - - # self.updater = PlotUpdater(data_plot, mpl_plotter, fig_wf, ax_probe, self.controller) - # for w in self.controller.values(): - # w.observe(self.updater) - for w in self.controller.values(): w.observe(self._update_ipywidget) @@ -335,11 +318,9 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): ) # a first update - # self.updater(None) self._update_ipywidget(None) if backend_kwargs["display"]: - # self.check_backend() display(self.widget) def _update_ipywidget(self, change): @@ -369,18 +350,14 @@ def _update_ipywidget(self, change): else: backend_kwargs["figure"] = self.fig_wf - # self.mpl_plotter.do_plot(data_plot, **backend_kwargs) self.plot_matplotlib(data_plot, **backend_kwargs) if same_axis: - # self.mpl_plotter.ax.axis("equal") self.ax.axis("equal") if hide_axis: - # self.mpl_plotter.ax.axis("off") self.ax.axis("off") else: if hide_axis: for i in range(len(unit_ids)): - # ax = self.mpl_plotter.axes.flatten()[i] ax = self.axes.flatten()[i] ax.axis("off") diff --git a/src/spikeinterface/widgets/unit_waveforms_density_map.py b/src/spikeinterface/widgets/unit_waveforms_density_map.py index 3320a232c6..e8a6868e92 100644 --- a/src/spikeinterface/widgets/unit_waveforms_density_map.py +++ b/src/spikeinterface/widgets/unit_waveforms_density_map.py @@ -33,8 +33,6 @@ class UnitWaveformDensityMapWidget(BaseWidget): all channel per units, default False """ - # possible_backends = {} - def __init__( self, waveform_extractor, @@ -162,10 +160,8 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): from .utils_matplotlib import make_mpl_figure dp = to_attr(data_plot) - # backend_kwargs = self.update_backend_kwargs(**backend_kwargs) if backend_kwargs["axes"] is not None or backend_kwargs["ax"] is not None: - # self.make_mpl_figure(**backend_kwargs) self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) else: if dp.same_axis: @@ -174,7 +170,6 @@ def 
plot_matplotlib(self, data_plot, **backend_kwargs): num_axes = len(dp.unit_ids) backend_kwargs["ncols"] = 1 backend_kwargs["num_axes"] = num_axes - # self.make_mpl_figure(**backend_kwargs) self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) if dp.same_axis: diff --git a/src/spikeinterface/widgets/utils_matplotlib.py b/src/spikeinterface/widgets/utils_matplotlib.py index fb347552b1..a9128d7b66 100644 --- a/src/spikeinterface/widgets/utils_matplotlib.py +++ b/src/spikeinterface/widgets/utils_matplotlib.py @@ -65,11 +65,3 @@ def make_mpl_figure(figure=None, ax=None, axes=None, ncols=None, num_axes=None, figure.suptitle(figtitle) return figure, axes, ax - - # self.figure = figure - # self.ax = ax - # axes is always a 2D array of ax - # self.axes = axes - - # if figtitle is not None: - # self.figure.suptitle(figtitle) diff --git a/src/spikeinterface/widgets/utils_sortingview.py b/src/spikeinterface/widgets/utils_sortingview.py index 764246becf..24ae481a6b 100644 --- a/src/spikeinterface/widgets/utils_sortingview.py +++ b/src/spikeinterface/widgets/utils_sortingview.py @@ -3,14 +3,6 @@ from ..core.core_tools import check_json -sortingview_backend_kwargs_desc = { - "generate_url": "If True, the figurl URL is generated and printed. Default True", - "display": "If True and in jupyter notebook/lab, the widget is displayed in the cell. Default True.", - "figlabel": "The figurl figure label. Default None", - "height": "The height of the sortingview View in jupyter. Default None", -} -sortingview_default_backend_kwargs = {"generate_url": True, "display": True, "figlabel": None, "height": None} - def make_serializable(*args): dict_to_serialize = {int(i): a for i, a in enumerate(args)} From 69af6b41d37b206adab8cdb5e8f198c5f2f0f9ab Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 20 Jul 2023 11:25:51 +0200 Subject: [PATCH 125/166] plot_timeseries > plot_traces --- doc/api.rst | 2 +- doc/how_to/analyse_neuropixels.rst | 10 +++++----- doc/how_to/get_started.rst | 2 +- doc/modules/widgets.rst | 8 ++++---- examples/how_to/analyse_neuropixels.py | 10 +++++----- examples/how_to/get_started.py | 2 +- .../extractors/plot_1_read_various_formats.py | 2 +- .../widgets/plot_1_rec_gallery.py | 10 +++++----- .../extractors/tests/test_cbin_ibl_extractors.py | 2 +- .../preprocessing/tests/test_filter.py | 8 ++++---- .../preprocessing/tests/test_normalize_scale.py | 2 +- .../preprocessing/tests/test_phase_shift.py | 6 +++--- .../preprocessing/tests/test_rectify.py | 2 +- .../benchmark/benchmark_peak_localization.py | 2 +- .../widgets/_legacy_mpl_widgets/__init__.py | 2 +- .../widgets/_legacy_mpl_widgets/amplitudes.py | 6 +++--- .../widgets/_legacy_mpl_widgets/timeseries_.py | 8 ++++---- src/spikeinterface/widgets/spikes_on_traces.py | 6 +++--- src/spikeinterface/widgets/tests/test_widgets.py | 16 ++++++++-------- .../widgets/{timeseries.py => traces.py} | 10 +++++----- src/spikeinterface/widgets/widget_list.py | 13 ++++++++++--- 21 files changed, 68 insertions(+), 61 deletions(-) rename src/spikeinterface/widgets/{timeseries.py => traces.py} (98%) diff --git a/doc/api.rst b/doc/api.rst index e0a863bd9c..932c989c19 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -275,7 +275,7 @@ spikeinterface.widgets .. autofunction:: plot_spikes_on_traces .. autofunction:: plot_template_metrics .. autofunction:: plot_template_similarity - .. autofunction:: plot_timeseries + .. autofunction:: plot_traces .. autofunction:: plot_unit_depths .. autofunction:: plot_unit_locations .. 
autofunction:: plot_unit_summary diff --git a/doc/how_to/analyse_neuropixels.rst b/doc/how_to/analyse_neuropixels.rst index 0a02a47211..31dbc7422c 100644 --- a/doc/how_to/analyse_neuropixels.rst +++ b/doc/how_to/analyse_neuropixels.rst @@ -264,7 +264,7 @@ the ipywidgets interactive plotter .. code:: python %matplotlib widget - si.plot_timeseries({'filter':rec1, 'cmr': rec4}, backend='ipywidgets') + si.plot_traces({'filter':rec1, 'cmr': rec4}, backend='ipywidgets') Note that using these ipywidgets makes it possible to explore different preprocessing chains without saving the entire file to disk. Everything @@ -276,9 +276,9 @@ is lazy, so you can change the previous cell (parameters, step order, # here we use a static plot using the matplotlib backend fig, axs = plt.subplots(ncols=3, figsize=(20, 10)) - si.plot_timeseries(rec1, backend='matplotlib', clim=(-50, 50), ax=axs[0]) - si.plot_timeseries(rec4, backend='matplotlib', clim=(-50, 50), ax=axs[1]) - si.plot_timeseries(rec, backend='matplotlib', clim=(-50, 50), ax=axs[2]) + si.plot_traces(rec1, backend='matplotlib', clim=(-50, 50), ax=axs[0]) + si.plot_traces(rec4, backend='matplotlib', clim=(-50, 50), ax=axs[1]) + si.plot_traces(rec, backend='matplotlib', clim=(-50, 50), ax=axs[2]) for i, label in enumerate(('filter', 'cmr', 'final')): axs[i].set_title(label) @@ -292,7 +292,7 @@ is lazy, so you can change the previous cell (parameters, step order, # plot some channels fig, ax = plt.subplots(figsize=(20, 10)) some_chans = rec.channel_ids[[100, 150, 200, ]] - si.plot_timeseries({'filter':rec1, 'cmr': rec4}, backend='matplotlib', mode='line', ax=ax, channel_ids=some_chans) + si.plot_traces({'filter':rec1, 'cmr': rec4}, backend='matplotlib', mode='line', ax=ax, channel_ids=some_chans) diff --git a/doc/how_to/get_started.rst b/doc/how_to/get_started.rst index 02ccb872d1..0f6aa9eb3f 100644 --- a/doc/how_to/get_started.rst +++ b/doc/how_to/get_started.rst @@ -104,7 +104,7 @@ and the raster plots. .. code:: ipython3 - w_ts = sw.plot_timeseries(recording, time_range=(0, 5)) + w_ts = sw.plot_traces(recording, time_range=(0, 5)) w_rs = sw.plot_rasters(sorting_true, time_range=(0, 5)) diff --git a/doc/modules/widgets.rst b/doc/modules/widgets.rst index 9cb99ab5a1..86c541dfd0 100644 --- a/doc/modules/widgets.rst +++ b/doc/modules/widgets.rst @@ -123,7 +123,7 @@ The :code:`plot_*(..., backend="matplotlib")` functions come with the following .. code-block:: python # matplotlib backend - w = plot_timeseries(recording, backend="matplotlib") + w = plot_traces(recording, backend="matplotlib") **Output:** @@ -146,9 +146,9 @@ Each function has the following additional arguments: from spikeinterface.preprocessing import common_reference - # ipywidgets backend also supports multiple "layers" for plot_timeseries + # ipywidgets backend also supports multiple "layers" for plot_traces rec_dict = dict(filt=recording, cmr=common_reference(recording)) - w = sw.plot_timeseries(rec_dict, backend="ipywidgets") + w = sw.plot_traces(rec_dict, backend="ipywidgets") **Output:** @@ -171,7 +171,7 @@ The functions have the following additional arguments: .. 
code-block:: python # sortingview backend - w_ts = sw.plot_timeseries(recording, backend="ipywidgets") + w_ts = sw.plot_traces(recording, backend="ipywidgets") w_ss = sw.plot_sorting_summary(recording, backend="sortingview") diff --git a/examples/how_to/analyse_neuropixels.py b/examples/how_to/analyse_neuropixels.py index 9b9048cd0d..637120a591 100644 --- a/examples/how_to/analyse_neuropixels.py +++ b/examples/how_to/analyse_neuropixels.py @@ -82,7 +82,7 @@ # # ```python # # %matplotlib widget -# si.plot_timeseries({'filter':rec1, 'cmr': rec4}, backend='ipywidgets') +# si.plot_traces({'filter':rec1, 'cmr': rec4}, backend='ipywidgets') # ``` # # Note that using these ipywidgets makes it possible to explore different preprocessing chains without saving the entire file to disk. # @@ -94,9 +94,9 @@ # here we use a static plot using matplotlib backend fig, axs = plt.subplots(ncols=3, figsize=(20, 10)) -si.plot_timeseries(rec1, backend='matplotlib', clim=(-50, 50), ax=axs[0]) -si.plot_timeseries(rec4, backend='matplotlib', clim=(-50, 50), ax=axs[1]) -si.plot_timeseries(rec, backend='matplotlib', clim=(-50, 50), ax=axs[2]) +si.plot_traces(rec1, backend='matplotlib', clim=(-50, 50), ax=axs[0]) +si.plot_traces(rec4, backend='matplotlib', clim=(-50, 50), ax=axs[1]) +si.plot_traces(rec, backend='matplotlib', clim=(-50, 50), ax=axs[2]) for i, label in enumerate(('filter', 'cmr', 'final')): axs[i].set_title(label) # - @@ -104,7 +104,7 @@ # plot some channels fig, ax = plt.subplots(figsize=(20, 10)) some_chans = rec.channel_ids[[100, 150, 200, ]] -si.plot_timeseries({'filter':rec1, 'cmr': rec4}, backend='matplotlib', mode='line', ax=ax, channel_ids=some_chans) +si.plot_traces({'filter':rec1, 'cmr': rec4}, backend='matplotlib', mode='line', ax=ax, channel_ids=some_chans) # ### Should we save the preprocessed data to a binary file? diff --git a/examples/how_to/get_started.py b/examples/how_to/get_started.py index 266d585de9..7860c605af 100644 --- a/examples/how_to/get_started.py +++ b/examples/how_to/get_started.py @@ -92,7 +92,7 @@ # # Let's use the `spikeinterface.widgets` module to visualize the traces and the raster plots. -w_ts = sw.plot_timeseries(recording, time_range=(0, 5)) +w_ts = sw.plot_traces(recording, time_range=(0, 5)) w_rs = sw.plot_rasters(sorting_true, time_range=(0, 5)) # This is how you retrieve info from a `BaseRecording`... 
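For readers following the rename, here is a minimal standalone sketch of how the two entry points behave after this patch. It is a hypothetical illustration, not part of the patch itself: it assumes a toy recording built with `generate_recording` and that `spikeinterface.full` re-exports the widgets as usual; the deprecated-alias behavior it checks is the `plot_timeseries()` wrapper added in `widget_list.py` below.

```python
import warnings

import spikeinterface.full as si

# a small synthetic recording, in the spirit of the toy examples in the docs above
recording = si.generate_recording(num_channels=4, durations=[10.0])

# new canonical entry point
w = si.plot_traces(recording, time_range=(0, 5), backend="matplotlib")

# the old name is kept as a thin alias that warns and forwards to plot_traces()
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    si.plot_timeseries(recording, time_range=(0, 5), backend="matplotlib")

# the warning message points users at the new name
assert any("plot_traces" in str(c.message) for c in caught)
```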
diff --git a/examples/modules_gallery/extractors/plot_1_read_various_formats.py b/examples/modules_gallery/extractors/plot_1_read_various_formats.py index 98988a1746..ed0ba34396 100644 --- a/examples/modules_gallery/extractors/plot_1_read_various_formats.py +++ b/examples/modules_gallery/extractors/plot_1_read_various_formats.py @@ -87,7 +87,7 @@ import spikeinterface.widgets as sw -w_ts = sw.plot_timeseries(recording, time_range=(0, 5)) +w_ts = sw.plot_traces(recording, time_range=(0, 5)) w_rs = sw.plot_rasters(sorting, time_range=(0, 5)) plt.show() diff --git a/examples/modules_gallery/widgets/plot_1_rec_gallery.py b/examples/modules_gallery/widgets/plot_1_rec_gallery.py index d3d4792535..1544bbfc54 100644 --- a/examples/modules_gallery/widgets/plot_1_rec_gallery.py +++ b/examples/modules_gallery/widgets/plot_1_rec_gallery.py @@ -15,22 +15,22 @@ recording, sorting = se.toy_example(duration=10, num_channels=4, seed=0, num_segments=1) ############################################################################## -# plot_timeseries() +# plot_traces() # ~~~~~~~~~~~~~~~~~ -w_ts = sw.plot_timeseries(recording) +w_ts = sw.plot_traces(recording) ############################################################################## # We can select a time range -w_ts1 = sw.plot_timeseries(recording, time_range=(5, 8)) +w_ts1 = sw.plot_traces(recording, time_range=(5, 8)) ############################################################################## # We can color with groups recording2 = recording.clone() recording2.set_channel_groups(channel_ids=recording.get_channel_ids(), groups=[0, 0, 1, 1]) -w_ts2 = sw.plot_timeseries(recording2, time_range=(5, 8), color_groups=True) +w_ts2 = sw.plot_traces(recording2, time_range=(5, 8), color_groups=True) ############################################################################## # **Note**: each function returns a widget object, which allows us to access the figure and axis. 
@@ -41,7 +41,7 @@ ############################################################################## # We can also use the 'map' mode useful for high channel count -w_ts = sw.plot_timeseries(recording, mode='map', time_range=(5, 8), +w_ts = sw.plot_traces(recording, mode='map', time_range=(5, 8), show_channel_ids=True, order_channel_by_depth=True) ############################################################################## diff --git a/src/spikeinterface/extractors/tests/test_cbin_ibl_extractors.py b/src/spikeinterface/extractors/tests/test_cbin_ibl_extractors.py index 3c4e23f14a..2e364b13bc 100644 --- a/src/spikeinterface/extractors/tests/test_cbin_ibl_extractors.py +++ b/src/spikeinterface/extractors/tests/test_cbin_ibl_extractors.py @@ -22,7 +22,7 @@ class CompressedBinaryIblExtractorTest(RecordingCommonTestSuite, unittest.TestCa # ~ import matplotlib.pyplot as plt # ~ import spikeinterface.widgets as sw # ~ from probeinterface.plotting import plot_probe -# ~ sw.plot_timeseries(rec) +# ~ sw.plot_traces(rec) # ~ plot_probe(rec.get_probe()) # ~ plt.show() diff --git a/src/spikeinterface/preprocessing/tests/test_filter.py b/src/spikeinterface/preprocessing/tests/test_filter.py index 5d6cc0eb16..95e5a097ff 100644 --- a/src/spikeinterface/preprocessing/tests/test_filter.py +++ b/src/spikeinterface/preprocessing/tests/test_filter.py @@ -105,10 +105,10 @@ def test_filter_opencl(): # rec2_cached0 = rec2.save(chunk_size=1000,verbose=False, progress_bar=True, n_jobs=4) # import matplotlib.pyplot as plt - # from spikeinterface.widgets import plot_timeseries - # plot_timeseries(rec, segment_index=0) - # plot_timeseries(rec_filtered, segment_index=0) - # plot_timeseries(rec2_cached0, segment_index=0) + # from spikeinterface.widgets import plot_traces + # plot_traces(rec, segment_index=0) + # plot_traces(rec_filtered, segment_index=0) + # plot_traces(rec2_cached0, segment_index=0) # plt.show() diff --git a/src/spikeinterface/preprocessing/tests/test_normalize_scale.py b/src/spikeinterface/preprocessing/tests/test_normalize_scale.py index 45db8440b9..b62a73a8cb 100644 --- a/src/spikeinterface/preprocessing/tests/test_normalize_scale.py +++ b/src/spikeinterface/preprocessing/tests/test_normalize_scale.py @@ -30,7 +30,7 @@ def test_normalize_by_quantile(): rec2.save(verbose=False) # import matplotlib.pyplot as plt - # from spikeinterface.widgets import plot_timeseries + # from spikeinterface.widgets import plot_traces # fig, ax = plt.subplots() # ax.plot(rec.get_traces(segment_index=0)[:, 0], color='g') # ax.plot(rec2.get_traces(segment_index=0)[:, 0], color='r') diff --git a/src/spikeinterface/preprocessing/tests/test_phase_shift.py b/src/spikeinterface/preprocessing/tests/test_phase_shift.py index 41293b6c25..b1ccc433b3 100644 --- a/src/spikeinterface/preprocessing/tests/test_phase_shift.py +++ b/src/spikeinterface/preprocessing/tests/test_phase_shift.py @@ -104,9 +104,9 @@ def test_phase_shift(): # ~ import matplotlib.pyplot as plt # ~ import spikeinterface.full as si - # ~ si.plot_timeseries(rec, segment_index=0, time_range=[0, 10]) - # ~ si.plot_timeseries(rec2, segment_index=0, time_range=[0, 10]) - # ~ si.plot_timeseries(rec3, segment_index=0, time_range=[0, 10]) + # ~ si.plot_traces(rec, segment_index=0, time_range=[0, 10]) + # ~ si.plot_traces(rec2, segment_index=0, time_range=[0, 10]) + # ~ si.plot_traces(rec3, segment_index=0, time_range=[0, 10]) # ~ plt.show() diff --git a/src/spikeinterface/preprocessing/tests/test_rectify.py b/src/spikeinterface/preprocessing/tests/test_rectify.py index 
d4f58d3cc3..cca41ebf7d 100644 --- a/src/spikeinterface/preprocessing/tests/test_rectify.py +++ b/src/spikeinterface/preprocessing/tests/test_rectify.py @@ -27,7 +27,7 @@ def test_rectify(): assert traces.shape[1] == 1 # import matplotlib.pyplot as plt - # from spikeinterface.widgets import plot_timeseries + # from spikeinterface.widgets import plot_traces # fig, ax = plt.subplots() # ax.plot(rec.get_traces(segment_index=0)[:, 0], color='g') # ax.plot(rec2.get_traces(segment_index=0)[:, 0], color='r') diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py index b5ad24a5b3..e1a8ade22b 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_localization.py @@ -455,7 +455,7 @@ def plot_figure_1(benchmark, mode="average", cell_ind="auto"): ) print(benchmark.recording) - # si.plot_timeseries(benchmark.recording, mode='line', time_range=(times[0]-0.01, times[0] + 0.1), channel_ids=benchmark.recording.channel_ids[:20], ax=axs[0, 1]) + # si.plot_traces(benchmark.recording, mode='line', time_range=(times[0]-0.01, times[0] + 0.1), channel_ids=benchmark.recording.channel_ids[:20], ax=axs[0, 1]) # axs[0, 1].set_ylabel('Neurons') # si.plot_spikes_on_traces(benchmark.waveforms, unit_ids=[unit_id], time_range=(times[0]-0.01, times[0] + 0.1), unit_colors={unit_id : 'r'}, ax=axs[0, 1], diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py index 06f68a754e..81f2e4009b 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py @@ -1,5 +1,5 @@ # basics -# from .timeseries import plot_timeseries, TimeseriesWidget +# from .timeseries import plot_timeseries, TracesWidget from .rasters import plot_rasters, RasterWidget from .probemap import plot_probe_map, ProbeMapWidget diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/amplitudes.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/amplitudes.py index 37bfab9d66..dd7c801e9c 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/amplitudes.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/amplitudes.py @@ -31,7 +31,7 @@ def plot(self): self._do_plot() -class AmplitudeTimeseriesWidget(AmplitudeBaseWidget): +class AmplitudeTracesWidget(AmplitudeBaseWidget): """ Plots waveform amplitudes distribution. @@ -130,12 +130,12 @@ def _do_plot(self): def plot_amplitudes_timeseries(*args, **kwargs): - W = AmplitudeTimeseriesWidget(*args, **kwargs) + W = AmplitudeTracesWidget(*args, **kwargs) W.plot() return W -plot_amplitudes_timeseries.__doc__ = AmplitudeTimeseriesWidget.__doc__ +plot_amplitudes_timeseries.__doc__ = AmplitudeTracesWidget.__doc__ def plot_amplitudes_distribution(*args, **kwargs): diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/timeseries_.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/timeseries_.py index 5856549da3..ab6fa2ace5 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/timeseries_.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/timeseries_.py @@ -6,7 +6,7 @@ import scipy.spatial -class TimeseriesWidget(BaseWidget): +class TracesWidget(BaseWidget): """ Plots recording timeseries. 
@@ -46,7 +46,7 @@ class TimeseriesWidget(BaseWidget): Returns ------- - W: TimeseriesWidget + W: TracesWidget The output widget """ @@ -225,9 +225,9 @@ def _initialize_stats(self): def plot_timeseries(*args, **kwargs): - W = TimeseriesWidget(*args, **kwargs) + W = TracesWidget(*args, **kwargs) W.plot() return W -plot_timeseries.__doc__ = TimeseriesWidget.__doc__ +plot_timeseries.__doc__ = TracesWidget.__doc__ diff --git a/src/spikeinterface/widgets/spikes_on_traces.py b/src/spikeinterface/widgets/spikes_on_traces.py index ab4e629a2e..e7bcff0832 100644 --- a/src/spikeinterface/widgets/spikes_on_traces.py +++ b/src/spikeinterface/widgets/spikes_on_traces.py @@ -2,7 +2,7 @@ from .base import BaseWidget, to_attr from .utils import get_unit_colors -from .timeseries import TimeseriesWidget +from .traces import TracesWidget from ..core import ChannelSparsity from ..core.template_tools import get_template_extremum_channel from ..core.waveform_extractor import WaveformExtractor @@ -150,7 +150,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): sorting = we.sorting # first plot time series - ts_widget = TimeseriesWidget(recording, **dp.options, backend="matplotlib", **backend_kwargs) + ts_widget = TracesWidget(recording, **dp.options, backend="matplotlib", **backend_kwargs) self.ax = ts_widget.ax self.axes = ts_widget.axes self.figure = ts_widget.figure @@ -257,7 +257,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): width_cm = backend_kwargs["width_cm"] # plot timeseries - ts_widget = TimeseriesWidget(we.recording, **dp.options, backend="ipywidgets", **backend_kwargs_ts) + ts_widget = TracesWidget(we.recording, **dp.options, backend="ipywidgets", **backend_kwargs_ts) self.ax = ts_widget.ax self.axes = ts_widget.axes self.figure = ts_widget.figure diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index 610da470e8..96c6ab80eb 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -86,16 +86,16 @@ def setUpClass(cls): cls.gt_comp = sc.compare_sorter_to_ground_truth(cls.sorting, cls.sorting) - def test_plot_timeseries(self): - possible_backends = list(sw.TimeseriesWidget.get_possible_backends()) + def test_plot_traces(self): + possible_backends = list(sw.TracesWidget.get_possible_backends()) for backend in possible_backends: if ON_GITHUB and backend == "sortingview": continue if backend not in self.skip_backends: - sw.plot_timeseries( + sw.plot_traces( self.recording, mode="map", show_channel_ids=True, backend=backend, **self.backend_kwargs[backend] ) - sw.plot_timeseries( + sw.plot_traces( self.recording, mode="map", show_channel_ids=True, @@ -105,8 +105,8 @@ def test_plot_timeseries(self): ) if backend != "sortingview": - sw.plot_timeseries(self.recording, mode="auto", backend=backend, **self.backend_kwargs[backend]) - sw.plot_timeseries( + sw.plot_traces(self.recording, mode="auto", backend=backend, **self.backend_kwargs[backend]) + sw.plot_traces( self.recording, mode="line", show_channel_ids=True, @@ -114,7 +114,7 @@ def test_plot_timeseries(self): **self.backend_kwargs[backend], ) # multi layer - sw.plot_timeseries( + sw.plot_traces( {"rec0": self.recording, "rec1": scale(self.recording, gain=0.8, offset=0)}, color="r", mode="line", @@ -337,7 +337,7 @@ def test_sorting_summary(self): # mytest.test_plot_unit_waveforms_density_map() # mytest.test_plot_unit_summary() # mytest.test_plot_all_amplitudes_distributions() - # mytest.test_plot_timeseries() + # 
mytest.test_plot_traces() # mytest.test_plot_unit_waveforms() # mytest.test_plot_unit_templates() # mytest.test_plot_unit_templates() diff --git a/src/spikeinterface/widgets/timeseries.py b/src/spikeinterface/widgets/traces.py similarity index 98% rename from src/spikeinterface/widgets/timeseries.py rename to src/spikeinterface/widgets/traces.py index 9439694639..53f1593260 100644 --- a/src/spikeinterface/widgets/timeseries.py +++ b/src/spikeinterface/widgets/traces.py @@ -7,7 +7,7 @@ from .utils import get_some_colors, array_to_image -class TimeseriesWidget(BaseWidget): +class TracesWidget(BaseWidget): """ Plots recording timeseries. @@ -54,7 +54,7 @@ class TimeseriesWidget(BaseWidget): Returns ------- - W: TimeseriesWidget + W: TracesWidget The output widget """ @@ -90,7 +90,7 @@ def __init__( recordings = {f"rec{i}": rec for i, rec in enumerate(recording)} rec0 = recordings[0] else: - raise ValueError("plot_timeseries recording must be recording or dict or list") + raise ValueError("plot_traces recording must be recording or dict or list") layer_keys = list(recordings.keys()) @@ -256,7 +256,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax.legend(loc="upper right") elif dp.mode == "map": - assert len(dp.list_traces) == 1, 'plot_timeseries with mode="map" do not support multi recording' + assert len(dp.list_traces) == 1, 'plot_traces with mode="map" do not support multi recording' assert len(dp.clims) == 1 clim = list(dp.clims.values())[0] extent = (dp.time_range[0], dp.time_range[1], min_y, max_y) @@ -501,7 +501,7 @@ def plot_sortingview(self, data_plot, **backend_kwargs): backend_kwargs = self.update_backend_kwargs(**backend_kwargs) dp = to_attr(data_plot) - assert dp.mode == "map", 'sortingview plot_timeseries is only mode="map"' + assert dp.mode == "map", 'sortingview plot_traces is only mode="map"' if not dp.order_channel_by_depth: warnings.warn( diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index eab0345d53..f3c640ff16 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -1,3 +1,5 @@ +import warnings + from .base import backend_kwargs_desc from .all_amplitudes_distributions import AllAmplitudesDistributionsWidget @@ -11,7 +13,7 @@ from .spikes_on_traces import SpikesOnTracesWidget from .template_metrics import TemplateMetricsWidget from .template_similarity import TemplateSimilarityWidget -from .timeseries import TimeseriesWidget +from .traces import TracesWidget from .unit_depths import UnitDepthsWidget from .unit_locations import UnitLocationsWidget from .unit_summary import UnitSummaryWidget @@ -32,7 +34,7 @@ SpikesOnTracesWidget, TemplateMetricsWidget, TemplateSimilarityWidget, - TimeseriesWidget, + TracesWidget, UnitDepthsWidget, UnitLocationsWidget, UnitSummaryWidget, @@ -79,10 +81,15 @@ plot_spikes_on_traces = SpikesOnTracesWidget plot_template_metrics = TemplateMetricsWidget plot_template_similarity = TemplateSimilarityWidget -plot_timeseries = TimeseriesWidget +plot_traces = TracesWidget plot_unit_depths = UnitDepthsWidget plot_unit_locations = UnitLocationsWidget plot_unit_summary = UnitSummaryWidget plot_unit_templates = UnitTemplatesWidget plot_unit_waveforms_density_map = UnitWaveformDensityMapWidget plot_unit_waveforms = UnitWaveformsWidget + + +def plot_timeseries(*args, **kwargs): + warnings.warn("plot_timeseries() is now plot_traces()") + return plot_traces(*args, **kwargs) From 019a5c8d59ec8b696c3c8f737b2d38c0574b6bc9 Mon Sep 17 00:00:00 2001 
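[Editor's note on the plot_timeseries -> plot_traces rename above] The patch keeps old call sites working through a thin shim in widget_list.py. Below is a minimal, self-contained sketch of that deprecation pattern. The stand-in functions are illustrative, not SpikeInterface's real API, and the DeprecationWarning category with stacklevel=2 is a suggested refinement; the shim in the patch itself uses a bare warnings.warn:

    import warnings

    def plot_traces(*args, **kwargs):
        # stand-in for the real widget entry point
        return ("traces", args, kwargs)

    def plot_timeseries(*args, **kwargs):
        # old name warns, then delegates to the new one
        warnings.warn("plot_timeseries() is now plot_traces()", DeprecationWarning, stacklevel=2)
        return plot_traces(*args, **kwargs)

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")  # DeprecationWarning is hidden by default
        out = plot_timeseries(1, mode="line")

    assert out == ("traces", (1,), {"mode": "line"})
    assert any("plot_traces" in str(w.message) for w in caught)

Old scripts keep running unchanged while the warning nudges users toward plot_traces().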
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 20 Jul 2023 09:28:01 +0000 Subject: [PATCH 126/166] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/widgets/base.py | 7 +++---- src/spikeinterface/widgets/utils_sortingview.py | 1 - 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/widgets/base.py b/src/spikeinterface/widgets/base.py index eaa151ccd9..dea46b8f51 100644 --- a/src/spikeinterface/widgets/base.py +++ b/src/spikeinterface/widgets/base.py @@ -77,10 +77,9 @@ def __init__( self.do_plot() # subclass must define one method per supported backend: - # def plot_matplotlib(self, data_plot, **backend_kwargs): - # def plot_ipywidgets(self, data_plot, **backend_kwargs): - # def plot_sortingview(self, data_plot, **backend_kwargs): - + # def plot_matplotlib(self, data_plot, **backend_kwargs): + # def plot_ipywidgets(self, data_plot, **backend_kwargs): + # def plot_sortingview(self, data_plot, **backend_kwargs): @classmethod def get_possible_backends(cls): diff --git a/src/spikeinterface/widgets/utils_sortingview.py b/src/spikeinterface/widgets/utils_sortingview.py index 24ae481a6b..50bbab99df 100644 --- a/src/spikeinterface/widgets/utils_sortingview.py +++ b/src/spikeinterface/widgets/utils_sortingview.py @@ -3,7 +3,6 @@ from ..core.core_tools import check_json - def make_serializable(*args): dict_to_serialize = {int(i): a for i, a in enumerate(args)} serializable_dict = check_json(dict_to_serialize) From 9f6636b7320f01aaa9ccd81a54faabfe4f6365dd Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 20 Jul 2023 11:32:03 +0200 Subject: [PATCH 127/166] Remove unecessary legacy widgets are are alreayd ported --- .../widgets/_legacy_mpl_widgets/__init__.py | 12 - .../widgets/_legacy_mpl_widgets/amplitudes.py | 147 ------------ .../_legacy_mpl_widgets/correlograms_.py | 107 --------- .../_legacy_mpl_widgets/depthamplitude.py | 58 ----- .../_legacy_mpl_widgets/unitlocalization_.py | 109 --------- .../_legacy_mpl_widgets/unitsummary.py | 104 --------- .../unitwaveformdensitymap_.py | 199 ---------------- .../_legacy_mpl_widgets/unitwaveforms_.py | 218 ------------------ 8 files changed, 954 deletions(-) delete mode 100644 src/spikeinterface/widgets/_legacy_mpl_widgets/amplitudes.py delete mode 100644 src/spikeinterface/widgets/_legacy_mpl_widgets/correlograms_.py delete mode 100644 src/spikeinterface/widgets/_legacy_mpl_widgets/depthamplitude.py delete mode 100644 src/spikeinterface/widgets/_legacy_mpl_widgets/unitlocalization_.py delete mode 100644 src/spikeinterface/widgets/_legacy_mpl_widgets/unitsummary.py delete mode 100644 src/spikeinterface/widgets/_legacy_mpl_widgets/unitwaveformdensitymap_.py delete mode 100644 src/spikeinterface/widgets/_legacy_mpl_widgets/unitwaveforms_.py diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py index 81f2e4009b..c0dcd7ea6e 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py @@ -6,25 +6,15 @@ # isi/ccg/acg from .isidistribution import plot_isi_distribution, ISIDistributionWidget -# from .correlograms import (plot_crosscorrelograms, CrossCorrelogramsWidget, -# plot_autocorrelograms, AutoCorrelogramsWidget) - # peak activity from .activity import plot_peak_activity_map, PeakActivityMapWidget # waveform/PC related -# from .unitwaveforms import 
plot_unit_waveforms, plot_unit_templates -# from .unitwaveformdensitymap import plot_unit_waveform_density_map, UnitWaveformDensityMapWidget -# from .amplitudes import plot_amplitudes_distribution from .principalcomponent import plot_principal_component -# from .unitlocalization import plot_unit_localization, UnitLocalizationWidget - # units on probe from .unitprobemap import plot_unit_probe_map, UnitProbeMapWidget -# from .depthamplitude import plot_units_depth_vs_amplitude - # comparison related from .confusionmatrix import plot_confusion_matrix, ConfusionMatrixWidget from .agreementmatrix import plot_agreement_matrix, AgreementMatrixWidget @@ -77,8 +67,6 @@ ComparisonPerformancesByTemplateSimilarity, ) -# unit summary -# from .unitsummary import plot_unit_summary, UnitSummaryWidget # unit presence from .presence import plot_presence, PresenceWidget diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/amplitudes.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/amplitudes.py deleted file mode 100644 index dd7c801e9c..0000000000 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/amplitudes.py +++ /dev/null @@ -1,147 +0,0 @@ -import numpy as np -from matplotlib import pyplot as plt - -from .basewidget import BaseWidget - -from ...postprocessing import compute_spike_amplitudes -from .utils import get_unit_colors - - -class AmplitudeBaseWidget(BaseWidget): - def __init__(self, waveform_extractor, unit_ids=None, compute_kwargs={}, unit_colors=None, figure=None, ax=None): - BaseWidget.__init__(self, figure, ax) - - self.we = waveform_extractor - - if self.we.is_extension("spike_amplitudes"): - sac = self.we.load_extension("spike_amplitudes") - self.amplitudes = sac.get_data(outputs="by_unit") - else: - self.amplitudes = compute_spike_amplitudes(self.we, outputs="by_unit", **compute_kwargs) - - if unit_ids is None: - unit_ids = waveform_extractor.sorting.unit_ids - self.unit_ids = unit_ids - - if unit_colors is None: - unit_colors = get_unit_colors(self.we.sorting) - self.unit_colors = unit_colors - - def plot(self): - self._do_plot() - - -class AmplitudeTracesWidget(AmplitudeBaseWidget): - """ - Plots waveform amplitudes distribution. - - Parameters - ---------- - waveform_extractor: WaveformExtractor - - amplitudes: None or pre computed amplitudes - If None then amplitudes are recomputed - peak_sign: 'neg', 'pos', 'both' - In case of recomputing amplitudes. - - Returns - ------- - W: AmplitudeDistributionWidget - The output widget - """ - - def _do_plot(self): - sorting = self.we.sorting - # ~ unit_ids = sorting.unit_ids - num_seg = sorting.get_num_segments() - fs = sorting.get_sampling_frequency() - - # TODO handle segment - ax = self.ax - for i, unit_id in enumerate(self.unit_ids): - for segment_index in range(num_seg): - times = sorting.get_unit_spike_train(unit_id, segment_index=segment_index) - times = times / fs - amps = self.amplitudes[segment_index][unit_id] - ax.scatter(times, amps, color=self.unit_colors[unit_id], s=3, alpha=1) - - if i == 0: - ax.set_title(f"segment {segment_index}") - if i == len(self.unit_ids) - 1: - ax.set_xlabel("Times [s]") - if segment_index == 0: - ax.set_ylabel(f"Amplitude") - - ylims = ax.get_ylim() - if np.max(ylims) < 0: - ax.set_ylim(min(ylims), 0) - if np.min(ylims) > 0: - ax.set_ylim(0, max(ylims)) - - -class AmplitudeDistributionWidget(AmplitudeBaseWidget): - """ - Plots waveform amplitudes distribution. 
- - Parameters - ---------- - waveform_extractor: WaveformExtractor - - amplitudes: None or pre computed amplitudes - If None then amplitudes are recomputed - peak_sign: 'neg', 'pos', 'both' - In case of recomputing amplitudes. - - Returns - ------- - W: AmplitudeDistributionWidget - The output widget - """ - - def _do_plot(self): - sorting = self.we.sorting - unit_ids = sorting.unit_ids - num_seg = sorting.get_num_segments() - - ax = self.ax - unit_amps = [] - for i, unit_id in enumerate(unit_ids): - amps = [] - for segment_index in range(num_seg): - amps.append(self.amplitudes[segment_index][unit_id]) - amps = np.concatenate(amps) - unit_amps.append(amps) - parts = ax.violinplot(unit_amps, showmeans=False, showmedians=False, showextrema=False) - - for i, pc in enumerate(parts["bodies"]): - color = self.unit_colors[unit_ids[i]] - pc.set_facecolor(color) - pc.set_edgecolor("black") - pc.set_alpha(1) - - ax.set_xticks(np.arange(len(unit_ids)) + 1) - ax.set_xticklabels([str(unit_id) for unit_id in unit_ids]) - - ylims = ax.get_ylim() - if np.max(ylims) < 0: - ax.set_ylim(min(ylims), 0) - if np.min(ylims) > 0: - ax.set_ylim(0, max(ylims)) - - -def plot_amplitudes_timeseries(*args, **kwargs): - W = AmplitudeTracesWidget(*args, **kwargs) - W.plot() - return W - - -plot_amplitudes_timeseries.__doc__ = AmplitudeTracesWidget.__doc__ - - -def plot_amplitudes_distribution(*args, **kwargs): - W = AmplitudeDistributionWidget(*args, **kwargs) - W.plot() - return W - - -plot_amplitudes_distribution.__doc__ = AmplitudeDistributionWidget.__doc__ diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/correlograms_.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/correlograms_.py deleted file mode 100644 index 8e12559066..0000000000 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/correlograms_.py +++ /dev/null @@ -1,107 +0,0 @@ -import numpy as np -from matplotlib import pyplot as plt -from .basewidget import BaseWidget - -from spikeinterface.postprocessing import compute_correlograms - - -class CrossCorrelogramsWidget(BaseWidget): - """ - Plots spike train cross-correlograms. - The diagonal is auto-correlogram. 
- - Parameters - ---------- - sorting: SortingExtractor - The sorting extractor object - unit_ids: list - List of unit ids - bin_ms: float - bins duration in ms - window_ms: float - Window duration in ms - """ - - def __init__(self, sorting, unit_ids=None, window_ms=100.0, bin_ms=1.0, axes=None): - if unit_ids is not None: - sorting = sorting.select_units(unit_ids) - self.sorting = sorting - self.compute_kwargs = dict(window_ms=window_ms, bin_ms=bin_ms) - - if axes is None: - n = len(sorting.unit_ids) - fig, axes = plt.subplots(nrows=n, ncols=n, sharex=True, sharey=True) - BaseWidget.__init__(self, None, None, axes) - - def plot(self): - correlograms, bins = compute_correlograms(self.sorting, **self.compute_kwargs) - bin_width = bins[1] - bins[0] - unit_ids = self.sorting.unit_ids - for i, unit_id1 in enumerate(unit_ids): - for j, unit_id2 in enumerate(unit_ids): - ccg = correlograms[i, j] - ax = self.axes[i, j] - if i == j: - color = "g" - else: - color = "k" - ax.bar(x=bins[:-1], height=ccg, width=bin_width, color=color, align="edge") - - for i, unit_id in enumerate(unit_ids): - self.axes[0, i].set_title(str(unit_id)) - self.axes[-1, i].set_xlabel("CCG (ms)") - - -def plot_crosscorrelograms(*args, **kwargs): - W = CrossCorrelogramsWidget(*args, **kwargs) - W.plot() - return W - - -plot_crosscorrelograms.__doc__ = CrossCorrelogramsWidget.__doc__ - - -class AutoCorrelogramsWidget(BaseWidget): - """ - Plots spike train auto-correlograms. - Parameters - ---------- - sorting: SortingExtractor - The sorting extractor object - unit_ids: list - List of unit ids - bin_ms: float - bins duration in ms - window_ms: float - Window duration in ms - """ - - def __init__(self, sorting, unit_ids=None, window_ms=100.0, bin_ms=1.0, ncols=5, axes=None): - if unit_ids is not None: - sorting = sorting.select_units(unit_ids) - self.sorting = sorting - self.compute_kwargs = dict(window_ms=window_ms, bin_ms=bin_ms) - - if axes is None: - num_axes = len(sorting.unit_ids) - BaseWidget.__init__(self, None, None, axes, ncols=ncols, num_axes=num_axes) - - def plot(self): - correlograms, bins = compute_correlograms(self.sorting, **self.compute_kwargs) - bin_width = bins[1] - bins[0] - unit_ids = self.sorting.unit_ids - for i, unit_id in enumerate(unit_ids): - ccg = correlograms[i, i] - ax = self.axes.flatten()[i] - color = "g" - ax.bar(x=bins[:-1], height=ccg, width=bin_width, color=color, align="edge") - ax.set_title(str(unit_id)) - - -def plot_autocorrelograms(*args, **kwargs): - W = AutoCorrelogramsWidget(*args, **kwargs) - W.plot() - return W - - -plot_autocorrelograms.__doc__ = AutoCorrelogramsWidget.__doc__ diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/depthamplitude.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/depthamplitude.py deleted file mode 100644 index a382fee9bc..0000000000 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/depthamplitude.py +++ /dev/null @@ -1,58 +0,0 @@ -import numpy as np -from matplotlib import pyplot as plt - -from .basewidget import BaseWidget - -from ...postprocessing import get_template_extremum_channel, get_template_extremum_amplitude -from .utils import get_unit_colors - - -class UnitsDepthAmplitudeWidget(BaseWidget): - def __init__(self, waveform_extractor, peak_sign="neg", depth_axis=1, unit_colors=None, figure=None, ax=None): - BaseWidget.__init__(self, figure, ax) - - self.we = waveform_extractor - self.peak_sign = peak_sign - self.depth_axis = depth_axis - if unit_colors is None: - unit_colors = get_unit_colors(self.we.sorting) - self.unit_colors = 
unit_colors - - def plot(self): - ax = self.ax - we = self.we - unit_ids = we.unit_ids - - channels_index = get_template_extremum_channel(we, peak_sign=self.peak_sign, outputs="index") - contact_positions = we.get_channel_locations() - - channel_depth = contact_positions[:, self.depth_axis] - unit_depth = [channel_depth[channels_index[unit_id]] for unit_id in unit_ids] - - unit_amplitude = get_template_extremum_amplitude(we, peak_sign=self.peak_sign) - unit_amplitude = np.abs([unit_amplitude[unit_id] for unit_id in unit_ids]) - - colors = [self.unit_colors[unit_id] for unit_id in unit_ids] - - num_spikes = np.zeros(len(unit_ids)) - for i, unit_id in enumerate(unit_ids): - for segment_index in range(we.get_num_segments()): - st = we.sorting.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index) - num_spikes[i] += st.size - - size = num_spikes / max(num_spikes) * 120 - ax.scatter(unit_amplitude, unit_depth, color=colors, s=size) - - ax.set_aspect(3) - ax.set_xlabel("amplitude") - ax.set_ylabel("depth [um]") - ax.set_xlim(0, max(unit_amplitude) * 1.2) - - -def plot_units_depth_vs_amplitude(*args, **kwargs): - W = UnitsDepthAmplitudeWidget(*args, **kwargs) - W.plot() - return W - - -plot_units_depth_vs_amplitude.__doc__ = UnitsDepthAmplitudeWidget.__doc__ diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/unitlocalization_.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/unitlocalization_.py deleted file mode 100644 index a2b8beea3f..0000000000 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/unitlocalization_.py +++ /dev/null @@ -1,109 +0,0 @@ -import numpy as np -import matplotlib.pylab as plt -from .basewidget import BaseWidget - -from probeinterface.plotting import plot_probe - -from spikeinterface.postprocessing import compute_unit_locations - -from .utils import get_unit_colors - - -class UnitLocalizationWidget(BaseWidget): - """ - Plot unit localization on probe. - - Parameters - ---------- - waveform_extractor: WaveformaExtractor - WaveformaExtractorr object - peaks: None or numpy array - Optionally can give already detected peaks - to avoid multiple computation. - method: str default 'center_of_mass' - Method used to estimate unit localization if 'unit_location' is None - method_kwargs: dict - Option for the method - unit_colors: None or dict - A dict key is unit_id and value is any color format handled by matplotlib. - If None, then the get_unit_colors() is internally used. - with_channel_ids: bool False default - add channel ids text on the probe - figure: matplotlib figure - The figure to be used. If not given a figure is created - ax: matplotlib axis - The axis to be used. 
If not given an axis is created - - Returns - ------- - W: ProbeMapWidget - The output widget - """ - - def __init__( - self, - waveform_extractor, - method="center_of_mass", - method_kwargs={}, - unit_colors=None, - with_channel_ids=False, - figure=None, - ax=None, - ): - BaseWidget.__init__(self, figure, ax) - - self.waveform_extractor = waveform_extractor - self.method = method - self.method_kwargs = method_kwargs - - if unit_colors is None: - unit_colors = get_unit_colors(waveform_extractor.sorting) - self.unit_colors = unit_colors - - self.with_channel_ids = with_channel_ids - - def plot(self): - we = self.waveform_extractor - unit_ids = we.unit_ids - - if we.is_extension("unit_locations"): - unit_locations = we.load_extension("unit_locations").get_data() - else: - unit_locations = compute_unit_locations(we, method=self.method, **self.method_kwargs) - - ax = self.ax - probegroup = we.get_probegroup() - probe_shape_kwargs = dict(facecolor="w", edgecolor="k", lw=0.5, alpha=1.0) - contacts_kargs = dict(alpha=1.0, edgecolor="k", lw=0.5) - - for probe in probegroup.probes: - text_on_contact = None - if self.with_channel_ids: - text_on_contact = self.waveform_extractor.recording.channel_ids - - poly_contact, poly_contour = plot_probe( - probe, - ax=ax, - contacts_colors="w", - contacts_kargs=contacts_kargs, - probe_shape_kwargs=probe_shape_kwargs, - text_on_contact=text_on_contact, - ) - poly_contact.set_zorder(2) - if poly_contour is not None: - poly_contour.set_zorder(1) - - ax.set_title("") - - color = np.array([self.unit_colors[unit_id] for unit_id in unit_ids]) - loc = ax.scatter(unit_locations[:, 0], unit_locations[:, 1], marker="1", color=color, s=80, lw=3) - loc.set_zorder(3) - - -def plot_unit_localization(*args, **kwargs): - W = UnitLocalizationWidget(*args, **kwargs) - W.plot() - return W - - -plot_unit_localization.__doc__ = UnitLocalizationWidget.__doc__ diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/unitsummary.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/unitsummary.py deleted file mode 100644 index a1d0589abc..0000000000 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/unitsummary.py +++ /dev/null @@ -1,104 +0,0 @@ -import numpy as np -from matplotlib import pyplot as plt -from .basewidget import BaseWidget - -from .utils import get_unit_colors - -from .unitprobemap import plot_unit_probe_map -from .unitwaveformdensitymap_ import plot_unit_waveform_density_map -from .amplitudes import plot_amplitudes_timeseries -from .unitwaveforms_ import plot_unit_waveforms -from .isidistribution import plot_isi_distribution - - -class UnitSummaryWidget(BaseWidget): - """ - Plot a unit summary. - - If amplitudes are alreday computed they are displayed. - - Parameters - ---------- - waveform_extractor: WaveformExtractor - The waveform extractor object - unit_id: into or str - The unit id to plot the summary of - unit_colors: list or None - Optional matplotlib color for the unit - figure: matplotlib figure - The figure to be used. If not given a figure is created - ax: matplotlib axis - The axis to be used. 
If not given an axis is created - - Returns - ------- - W: UnitSummaryWidget - The output widget - """ - - def __init__(self, waveform_extractor, unit_id, unit_colors=None, figure=None, ax=None): - assert ax is None - # ~ assert axes is None - - if figure is None: - figure = plt.figure( - constrained_layout=False, - figsize=(15, 7), - ) - - BaseWidget.__init__(self, figure, None) - - self.waveform_extractor = waveform_extractor - self.recording = waveform_extractor.recording - self.sorting = waveform_extractor.sorting - self.unit_id = unit_id - - if unit_colors is None: - unit_colors = get_unit_colors(self.sorting) - self.unit_colors = unit_colors - - def plot(self): - we = self.waveform_extractor - - fig = self.figure - self.ax.remove() - - if we.is_extension("spike_amplitudes"): - nrows = 3 - else: - nrows = 2 - - gs = fig.add_gridspec(nrows, 6) - - ax = fig.add_subplot(gs[:, 0]) - plot_unit_probe_map(we, unit_ids=[self.unit_id], axes=[ax], colorbar=False) - ax.set_title("") - - ax = fig.add_subplot(gs[0:2, 1:3]) - plot_unit_waveforms(we, unit_ids=[self.unit_id], radius_um=60, axes=[ax], unit_colors=self.unit_colors) - ax.set_title(None) - - ax = fig.add_subplot(gs[0:2, 3:5]) - plot_unit_waveform_density_map(we, unit_ids=[self.unit_id], max_channels=1, ax=ax, same_axis=True) - ax.set_ylabel(None) - - ax = fig.add_subplot(gs[0:2, 5]) - plot_isi_distribution(we.sorting, unit_ids=[self.unit_id], axes=[ax]) - ax.set_title("") - - if we.is_extension("spike_amplitudes"): - ax = fig.add_subplot(gs[-1, 1:]) - plot_amplitudes_timeseries(we, unit_ids=[self.unit_id], ax=ax, unit_colors=self.unit_colors) - ax.set_ylabel(None) - ax.set_title(None) - - fig.suptitle(f"Unit ID: {self.unit_id}") - - -def plot_unit_summary(*args, **kwargs): - W = UnitSummaryWidget(*args, **kwargs) - W.plot() - return W - - -plot_unit_summary.__doc__ = UnitSummaryWidget.__doc__ diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/unitwaveformdensitymap_.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/unitwaveformdensitymap_.py deleted file mode 100644 index c5cbe07a7b..0000000000 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/unitwaveformdensitymap_.py +++ /dev/null @@ -1,199 +0,0 @@ -import numpy as np -from matplotlib import pyplot as plt - -from .basewidget import BaseWidget -from .utils import get_unit_colors -from ...postprocessing import get_template_channel_sparsity - - -class UnitWaveformDensityMapWidget(BaseWidget): - """ - Plots unit waveforms using heat map density. - - Parameters - ---------- - waveform_extractor: WaveformExtractor - channel_ids: list - The channel ids to display - unit_ids: list - List of unit ids. - plot_templates: bool - If True, templates are plotted over the waveforms - max_channels : None or int - If not None only max_channels are displayed per units. - Incompatible with with `radius_um` - radius_um: None or float - If not None, all channels within a circle around the peak waveform will be displayed - Incompatible with with `max_channels` - unit_colors: None or dict - A dict key is unit_id and value is any color format handled by matplotlib. - If None, then the get_unit_colors() is internally used. - same_axis: bool - If True then all density are plot on the same axis and then channels is the union - all channel per units. - set_title: bool - Create a plot title with the unit number if True. 
- plot_channels: bool - Plot channel locations below traces, only used if channel_locs is True - """ - - def __init__( - self, - waveform_extractor, - channel_ids=None, - unit_ids=None, - max_channels=None, - radius_um=None, - same_axis=False, - unit_colors=None, - ax=None, - axes=None, - ): - self.waveform_extractor = waveform_extractor - self.recording = waveform_extractor.recording - self.sorting = waveform_extractor.sorting - - if unit_ids is None: - unit_ids = self.sorting.get_unit_ids() - self.unit_ids = unit_ids - - if channel_ids is None: - channel_ids = self.recording.get_channel_ids() - self.channel_ids = channel_ids - - if unit_colors is None: - unit_colors = get_unit_colors(self.sorting) - self.unit_colors = unit_colors - - if radius_um is not None: - assert max_channels is None, "radius_um and max_channels are mutually exclusive" - if max_channels is not None: - assert radius_um is None, "radius_um and max_channels are mutually exclusive" - - self.radius_um = radius_um - self.max_channels = max_channels - self.same_axis = same_axis - - if axes is None and ax is None: - if same_axis: - fig, ax = plt.subplots() - axes = None - else: - nrows = len(unit_ids) - fig, axes = plt.subplots(nrows=nrows, squeeze=False) - axes = axes[:, 0] - ax = None - BaseWidget.__init__(self, figure=None, ax=ax, axes=axes) - - def plot(self): - we = self.waveform_extractor - - # channel sparsity - if self.radius_um is not None: - channel_inds = get_template_channel_sparsity(we, method="radius", outputs="index", radius_um=self.radius_um) - elif self.max_channels is not None: - channel_inds = get_template_channel_sparsity( - we, method="best_channels", outputs="index", num_channels=self.max_channels - ) - else: - # all channels - channel_inds = {unit_id: np.arange(len(self.channel_ids)) for unit_id in self.unit_ids} - channel_inds = {unit_id: inds for unit_id, inds in channel_inds.items() if unit_id in self.unit_ids} - - if self.same_axis: - # channel union - inds = np.unique(np.concatenate([inds.tolist() for inds in channel_inds.values()])) - channel_inds = {unit_id: inds for unit_id in self.unit_ids} - - # bins - templates = we.get_all_templates(unit_ids=self.unit_ids, mode="median") - bin_min = np.min(templates) * 1.3 - bin_max = np.max(templates) * 1.3 - bin_size = (bin_max - bin_min) / 100 - bins = np.arange(bin_min, bin_max, bin_size) - - # 2d histograms - all_hist2d = None - for unit_index, unit_id in enumerate(self.unit_ids): - chan_inds = channel_inds[unit_id] - - wfs = we.get_waveforms(unit_id) - wfs = wfs[:, :, chan_inds] - - # make histogram density - wfs_flat = wfs.swapaxes(1, 2).reshape(wfs.shape[0], -1) - hist2d = np.zeros((wfs_flat.shape[1], bins.size)) - indexes0 = np.arange(wfs_flat.shape[1]) - - wf_bined = np.floor((wfs_flat - bin_min) / bin_size).astype("int32") - wf_bined = wf_bined.clip(0, bins.size - 1) - for d in wf_bined: - hist2d[indexes0, d] += 1 - - if self.same_axis: - if all_hist2d is None: - all_hist2d = hist2d - else: - all_hist2d += hist2d - else: - ax = self.axes[unit_index] - im = ax.imshow( - hist2d.T, - interpolation="nearest", - origin="lower", - aspect="auto", - extent=(0, hist2d.shape[0], bin_min, bin_max), - cmap="hot", - ) - - if self.same_axis: - ax = self.ax - im = ax.imshow( - all_hist2d.T, - interpolation="nearest", - origin="lower", - aspect="auto", - extent=(0, hist2d.shape[0], bin_min, bin_max), - cmap="hot", - ) - - # plot median - for unit_index, unit_id in enumerate(self.unit_ids): - if self.same_axis: - ax = self.ax - else: - ax = self.axes[unit_index] 
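# --- editor's aside (not part of the patch) ---------------------------------
# The `hist2d` built above is the heart of this deleted density-map widget:
# every waveform sample increments one amplitude bin at its own sample index.
# A self-contained numpy rerun of that trick on toy data (all names below are
# illustrative, nothing here comes from SpikeInterface itself):
import numpy as np

rng = np.random.default_rng(0)
wfs_flat = rng.normal(size=(300, 120))  # (n_spikes, n_samples * n_channels)
bin_min, bin_max, n_bins = -4.0, 4.0, 100
bin_size = (bin_max - bin_min) / n_bins

hist2d = np.zeros((wfs_flat.shape[1], n_bins))
sample_idx = np.arange(wfs_flat.shape[1])
binned = np.floor((wfs_flat - bin_min) / bin_size).astype("int32").clip(0, n_bins - 1)
for spike_bins in binned:  # one increment per spike, vectorized over samples
    hist2d[sample_idx, spike_bins] += 1

assert hist2d.sum() == wfs_flat.size  # every sample landed in exactly one bin
# -----------------------------------------------------------------------------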
- chan_inds = channel_inds[unit_id] - template = templates[unit_index, :, chan_inds] - template_flat = template.flatten() - color = self.unit_colors[unit_id] - ax.plot(template_flat, color=color, lw=1) - - # final cosmetics - for unit_index, unit_id in enumerate(self.unit_ids): - if self.same_axis: - ax = self.ax - if unit_index != 0: - continue - else: - ax = self.axes[unit_index] - chan_inds = channel_inds[unit_id] - for i, chan_ind in enumerate(chan_inds): - if i != 0: - ax.axvline(i * wfs.shape[1], color="w", lw=3) - channel_id = self.recording.channel_ids[chan_ind] - x = i * wfs.shape[1] + wfs.shape[1] // 2 - y = (bin_max + bin_min) / 2.0 - ax.text(x, y, f"chan_id {channel_id}", color="w", ha="center", va="center") - - ax.set_xticks([]) - ax.set_ylabel(f"unit_id {unit_id}") - - -def plot_unit_waveform_density_map(*args, **kwargs): - W = UnitWaveformDensityMapWidget(*args, **kwargs) - W.plot() - return W - - -plot_unit_waveform_density_map.__doc__ = UnitWaveformDensityMapWidget.__doc__ diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/unitwaveforms_.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/unitwaveforms_.py deleted file mode 100644 index a1e28bbb82..0000000000 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/unitwaveforms_.py +++ /dev/null @@ -1,218 +0,0 @@ -import numpy as np -from matplotlib import pyplot as plt - -from .basewidget import BaseWidget -from .utils import get_unit_colors -from ...postprocessing import get_template_channel_sparsity - - -class UnitWaveformsWidget(BaseWidget): - """ - Plots unit waveforms. - - Parameters - ---------- - waveform_extractor: WaveformExtractor - channel_ids: list - The channel ids to display - unit_ids: list - List of unit ids. - plot_templates: bool - If True, templates are plotted over the waveforms - radius_um: None or float - If not None, all channels within a circle around the peak waveform will be displayed - Incompatible with with `max_channels` - max_channels : None or int - If not None only max_channels are displayed per units. - Incompatible with with `radius_um` - set_title: bool - Create a plot title with the unit number if True. - plot_channels: bool - Plot channel locations below traces. - axis_equal: bool - Equal aspect ratio for x and y axis, to visualize the array geometry to scale. - lw: float - Line width for the traces. - unit_colors: None or dict - A dict key is unit_id and value is any color format handled by matplotlib. - If None, then the get_unit_colors() is internally used. - unit_selected_waveforms: None or dict - A dict key is unit_id and value is the subset of waveforms indices that should be - be displayed - show_all_channels: bool - Show the whole probe if True, or only selected channels if False - The axis to be used. If not given an axis is created - axes: list of matplotlib axes - The axes to be used for the individual plots. If not given the required axes are created. 
If provided, the ax - and figure parameters are ignored - """ - - def __init__( - self, - waveform_extractor, - channel_ids=None, - unit_ids=None, - plot_waveforms=True, - plot_templates=True, - plot_channels=False, - unit_colors=None, - max_channels=None, - radius_um=None, - ncols=5, - axes=None, - lw=2, - axis_equal=False, - unit_selected_waveforms=None, - set_title=True, - ): - self.waveform_extractor = waveform_extractor - self._recording = waveform_extractor.recording - self._sorting = waveform_extractor.sorting - sorting = waveform_extractor.sorting - - if unit_ids is None: - unit_ids = self._sorting.get_unit_ids() - self._unit_ids = unit_ids - if channel_ids is None: - channel_ids = self._recording.get_channel_ids() - self._channel_ids = channel_ids - - if unit_colors is None: - unit_colors = get_unit_colors(self._sorting) - self.unit_colors = unit_colors - - self.ncols = ncols - self._plot_waveforms = plot_waveforms - self._plot_templates = plot_templates - self._plot_channels = plot_channels - - if radius_um is not None: - assert max_channels is None, "radius_um and max_channels are mutually exclusive" - if max_channels is not None: - assert radius_um is None, "radius_um and max_channels are mutually exclusive" - - self.radius_um = radius_um - self.max_channels = max_channels - self.unit_selected_waveforms = unit_selected_waveforms - - # TODO - self._lw = lw - self._axis_equal = axis_equal - - self._set_title = set_title - - if axes is None: - num_axes = len(unit_ids) - else: - num_axes = None - BaseWidget.__init__(self, None, None, axes, ncols=ncols, num_axes=num_axes) - - def plot(self): - self._do_plot() - - def _do_plot(self): - we = self.waveform_extractor - unit_ids = self._unit_ids - channel_ids = self._channel_ids - - channel_locations = self._recording.get_channel_locations(channel_ids=channel_ids) - templates = we.get_all_templates(unit_ids=unit_ids) - - xvectors, y_scale, y_offset = get_waveforms_scales(we, templates, channel_locations) - - ncols = min(self.ncols, len(unit_ids)) - nrows = int(np.ceil(len(unit_ids) / ncols)) - - if self.radius_um is not None: - channel_inds = get_template_channel_sparsity(we, method="radius", outputs="index", radius_um=self.radius_um) - elif self.max_channels is not None: - channel_inds = get_template_channel_sparsity( - we, method="best_channels", outputs="index", num_channels=self.max_channels - ) - else: - # all channels - channel_inds = {unit_id: slice(None) for unit_id in unit_ids} - - for i, unit_id in enumerate(unit_ids): - ax = self.axes.flatten()[i] - color = self.unit_colors[unit_id] - - chan_inds = channel_inds[unit_id] - xvectors_flat = xvectors[:, chan_inds].T.flatten() - - # plot waveforms - if self._plot_waveforms: - wfs = we.get_waveforms(unit_id) - if self.unit_selected_waveforms is not None: - wfs = wfs[self.unit_selected_waveforms[unit_id]][:, :, chan_inds] - else: - wfs = wfs[:, :, chan_inds] - wfs = wfs * y_scale + y_offset[None, :, chan_inds] - wfs_flat = wfs.swapaxes(1, 2).reshape(wfs.shape[0], -1).T - ax.plot(xvectors_flat, wfs_flat, lw=1, alpha=0.3, color=color) - - # plot template - if self._plot_templates: - template = templates[i, :, :][:, chan_inds] * y_scale + y_offset[:, chan_inds] - if self._plot_waveforms and self._plot_templates: - color = "k" - ax.plot(xvectors_flat, template.T.flatten(), lw=1, color=color) - template_label = unit_ids[i] - ax.set_title(f"template {template_label}") - - # plot channels - if self._plot_channels: - # TODO enhance this - ax.scatter(channel_locations[:, 0], 
channel_locations[:, 1], color="k") - - -def get_waveforms_scales(we, templates, channel_locations): - """ - Return scales and x_vector for templates plotting - """ - wf_max = np.max(templates) - wf_min = np.max(templates) - - x_chans = np.unique(channel_locations[:, 0]) - if x_chans.size > 1: - delta_x = np.min(np.diff(x_chans)) - else: - delta_x = 40.0 - - y_chans = np.unique(channel_locations[:, 1]) - if y_chans.size > 1: - delta_y = np.min(np.diff(y_chans)) - else: - delta_y = 40.0 - - m = max(np.abs(wf_max), np.abs(wf_min)) - y_scale = delta_y / m * 0.7 - - y_offset = channel_locations[:, 1][None, :] - - xvect = delta_x * (np.arange(we.nsamples) - we.nbefore) / we.nsamples * 0.7 - - xvectors = channel_locations[:, 0][None, :] + xvect[:, None] - # put nan for discontinuity - xvectors[-1, :] = np.nan - - return xvectors, y_scale, y_offset - - -def plot_unit_waveforms(*args, **kwargs): - W = UnitWaveformsWidget(*args, **kwargs) - W.plot() - return W - - -plot_unit_waveforms.__doc__ = UnitWaveformsWidget.__doc__ - - -def plot_unit_templates(*args, **kwargs): - kwargs["plot_waveforms"] = False - W = UnitWaveformsWidget(*args, **kwargs) - W.plot() - return W - - -plot_unit_templates.__doc__ = UnitWaveformsWidget.__doc__ From 21078fd23071120c37808dd050db392c15b2409b Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 20 Jul 2023 11:47:36 +0200 Subject: [PATCH 128/166] oups --- src/spikeinterface/core/basesorting.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index 997b6995ae..bad007aeae 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -142,6 +142,7 @@ def get_unit_spike_train( times = self.get_times(segment_index=segment_index) return times[spike_frames] else: + segment = self._sorting_segments[segment_index] t_start = segment._t_start if segment._t_start is not None else 0 spike_times = spike_frames / self.get_sampling_frequency() return t_start + spike_times From 87369605653a456f0118d96a0e360236ec57343e Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 20 Jul 2023 12:04:02 +0200 Subject: [PATCH 129/166] more fix --- .../tests/test_quality_metric_calculator.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py index 4bc61768c0..1824c6df14 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py @@ -279,7 +279,7 @@ def test_recordingless(self): def test_empty_units(self): we = self.we1 empty_spike_train = np.array([], dtype="int64") - empty_sorting = NumpySorting.from_dict( + empty_sorting = NumpySorting.from_unit_dict( {100: empty_spike_train, 200: empty_spike_train, 300: empty_spike_train}, sampling_frequency=we.sampling_frequency, ) @@ -296,7 +296,9 @@ def test_empty_units(self): if __name__ == "__main__": test = QualityMetricsExtensionTest() test.setUp() - test.test_drift_metrics() - test.test_extension() + # test.test_drift_metrics() + # test.test_extension() # test.test_nn_metrics() # test.test_peak_sign() + test.test_empty_units() + From 81d231dee7d37d32ce44f009a697f48ccbdd4416 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 20 Jul 2023 10:04:23 +0000 Subject: [PATCH 130/166] [pre-commit.ci] 
auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../qualitymetrics/tests/test_quality_metric_calculator.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py index 1824c6df14..bd792e1aac 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py @@ -301,4 +301,3 @@ def test_empty_units(self): # test.test_nn_metrics() # test.test_peak_sign() test.test_empty_units() - From 6877be96adaea29eb25040c54ffc6b244218006a Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 20 Jul 2023 12:05:18 +0200 Subject: [PATCH 131/166] more fix --- src/spikeinterface/curation/tests/test_curationsorting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/curation/tests/test_curationsorting.py b/src/spikeinterface/curation/tests/test_curationsorting.py index ddc57bb726..91bc21a49f 100644 --- a/src/spikeinterface/curation/tests/test_curationsorting.py +++ b/src/spikeinterface/curation/tests/test_curationsorting.py @@ -81,7 +81,7 @@ def test_curation(): ) # Test with empty sorting - empty_sorting = CurationSorting(NumpySorting.from_dict({}, parent_sort.sampling_frequency)) + empty_sorting = CurationSorting(NumpySorting.from_unit_dict({}, parent_sort.sampling_frequency)) if __name__ == "__main__": From efc5b9c2bb121ef046ceb43099b0b9fdf33b064e Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 20 Jul 2023 12:18:13 +0200 Subject: [PATCH 132/166] fixes to neuroscope --- src/spikeinterface/extractors/neoextractors/neuroscope.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/extractors/neoextractors/neuroscope.py b/src/spikeinterface/extractors/neoextractors/neuroscope.py index a41441b8b7..801b9c1928 100644 --- a/src/spikeinterface/extractors/neoextractors/neuroscope.py +++ b/src/spikeinterface/extractors/neoextractors/neuroscope.py @@ -62,9 +62,9 @@ def map_to_neo_kwargs(cls, file_path, xml_file_path=None): # binary_file is the binary file in .dat, .lfp, .eeg if xml_file_path is not None: - neo_kwargs = {"binary_file": str(file_path), "filename": str(xml_file_path)} + neo_kwargs = {"binary_file": Path(file_path), "filename": Path(xml_file_path)} else: - neo_kwargs = {"filename": str(file_path)} + neo_kwargs = {"filename": Path(file_path)} return neo_kwargs From a7c62e2f61587c0a7e1bea06280725beeecfe8c8 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 20 Jul 2023 12:30:22 +0200 Subject: [PATCH 133/166] More fix --- src/spikeinterface/comparison/studytools.py | 2 +- src/spikeinterface/core/base.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/comparison/studytools.py b/src/spikeinterface/comparison/studytools.py index 08f3613bc2..79227c865f 100644 --- a/src/spikeinterface/comparison/studytools.py +++ b/src/spikeinterface/comparison/studytools.py @@ -53,7 +53,7 @@ def setup_comparison_study(study_folder, gt_dict, **job_kwargs): for rec_name, (recording, sorting_gt) in gt_dict.items(): # write recording using save with binary folder = study_folder / "ground_truth" / rec_name - sorting_gt.save(folder=folder, format="npz") + sorting_gt.save(folder=folder, format="numpy_folder") folder = study_folder / "raw_files" / rec_name recording.save(folder=folder, format="binary", **job_kwargs) diff --git 
a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 61ba7b535c..87c0805630 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -722,7 +722,7 @@ def save(self, **kwargs) -> "BaseExtractor": Parameters ---------- kwargs: Keyword arguments for saving. - * format: "memory", "zarr", or "binary" (for recording) / "memory" or "npz" for sorting. + * format: "memory", "zarr", or "binary" (for recording) / "memory" or "numpy_folder" or "npz_folder" for sorting. In case format is not memory, the recording is saved to a folder. See format specific functions for more info (`save_to_memory()`, `save_to_folder()`, `save_to_zarr()`) * folder: if provided, the folder path where the object is saved From c6afef90a74d8fb80ed8fea697bdaf4b1fd44e7b Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 20 Jul 2023 13:32:29 +0200 Subject: [PATCH 134/166] Update release notes --- doc/releases/0.98.2.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/releases/0.98.2.rst b/doc/releases/0.98.2.rst index 2a326d1eb1..dc35fef860 100644 --- a/doc/releases/0.98.2.rst +++ b/doc/releases/0.98.2.rst @@ -15,3 +15,4 @@ Minor release with some bug fixes. * Update Tridesclous 1.6.8 (#1857) * Eliminate restore keys in CI and simplify installation of dev version dependencies (#1858) * Allow order_channel_by_depth to accept dimentsions as list (#1861) +* Fixes to Neuroscope extractor before neo release 0.13 (#1863) From e3342de59b90828fa09068f7eee41e261994fcfd Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 20 Jul 2023 14:41:03 +0200 Subject: [PATCH 135/166] Update doc/releases/0.98.2.rst --- doc/releases/0.98.2.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/releases/0.98.2.rst b/doc/releases/0.98.2.rst index dc35fef860..134aeba960 100644 --- a/doc/releases/0.98.2.rst +++ b/doc/releases/0.98.2.rst @@ -3,7 +3,7 @@ SpikeInterface 0.98.2 release notes ----------------------------------- -19th July 2023 +20th July 2023 Minor release with some bug fixes. From 085a99f6045dbe896a2d560c027d4327dd2a19cd Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 20 Jul 2023 15:01:48 +0200 Subject: [PATCH 136/166] Fix plot_traces SV and add plot_motion to API --- doc/api.rst | 1 + src/spikeinterface/widgets/traces.py | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/api.rst b/doc/api.rst index 932c989c19..2e9fc1567a 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -269,6 +269,7 @@ spikeinterface.widgets .. autofunction:: plot_amplitudes .. autofunction:: plot_autocorrelograms .. autofunction:: plot_crosscorrelograms + .. autofunction:: plot_motion .. autofunction:: plot_quality_metrics .. autofunction:: plot_sorting_summary .. 
autofunction:: plot_spike_locations
diff --git a/src/spikeinterface/widgets/traces.py b/src/spikeinterface/widgets/traces.py
index 53f1593260..c9dc04811a 100644
--- a/src/spikeinterface/widgets/traces.py
+++ b/src/spikeinterface/widgets/traces.py
@@ -498,7 +498,6 @@ def plot_sortingview(self, data_plot, **backend_kwargs):
         except ImportError:
             raise ImportError("To use the timeseries in sorting view you need the pyvips package.")
 
-        backend_kwargs = self.update_backend_kwargs(**backend_kwargs)
         dp = to_attr(data_plot)
 
         assert dp.mode == "map", 'sortingview plot_traces is only mode="map"'
From 88d6fdf225b920f4b8c9c25ed3b5790d2d9de725 Mon Sep 17 00:00:00 2001
From: Samuel Garcia <sam.garcia.die@gmail.com>
Date: Thu, 20 Jul 2023 15:09:41 +0200
Subject: [PATCH 137/166] after release

---
 pyproject.toml | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 8452cd5fa5..59afcff264 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "spikeinterface"
-version = "0.98.2"
+version = "0.99.0"
 authors = [
   { name="Alessio Buccino", email="alessiop.buccino@gmail.com" },
   { name="Samuel Garcia", email="sam.garcia.die@gmail.com" },
@@ -138,8 +138,8 @@ test = [
 
     # for github test : probeinterface and neo from master
     # for release we need pypi, so this need to be commented
-    # "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git",
-    # "neo @ git+https://github.com/NeuralEnsemble/python-neo.git",
+    "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git",
+    "neo @ git+https://github.com/NeuralEnsemble/python-neo.git",
 ]
 
 docs = [
@@ -155,8 +155,8 @@ docs = [
     "hdbscan>=0.8.33",  # For sorters, probably spikingcircus
     "numba",  # For sorters, probably spikingcircus
     # for release we need pypi, so this needs to be commented
-    # "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git",  # We always build from the latest version
-    # "neo @ git+https://github.com/NeuralEnsemble/python-neo.git",  # We always build from the latest version
+    "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git",  # We always build from the latest version
+    "neo @ git+https://github.com/NeuralEnsemble/python-neo.git",  # We always build from the latest version
 ]
 
 
From 2cfda687bae02c55573a288719f5f6c4a2cdc863 Mon Sep 17 00:00:00 2001
From: Sebastien
Date: Thu, 20 Jul 2023 15:20:28 +0200
Subject: [PATCH 138/166] Harmonize radius_um

---
 doc/how_to/analyse_neuropixels.rst            |  4 +-
 doc/how_to/get_started.rst                    |  2 +-
 doc/how_to/handle_drift.rst                   |  4 +-
 doc/modules/motion_correction.rst             |  2 +-
 doc/modules/sortingcomponents.rst             |  4 +-
 examples/how_to/analyse_neuropixels.py        |  4 +-
 .../widgets/plot_4_peaks_gallery.py           |  2 +-
 .../postprocessing/unit_localization.py       |  4 +-
 src/spikeinterface/preprocessing/motion.py    | 12 ++---
 .../sorters/internal/spyking_circus2.py       |  6 +--
 .../sorters/internal/tridesclous2.py          |  8 +--
 .../sortingcomponents/clustering/circus.py    |  4 +-
 .../clustering/position_and_features.py       |  8 +--
 .../clustering/random_projections.py          |  4 +-
 .../sortingcomponents/features_from_peaks.py  | 54 +++++++++----------
 .../sortingcomponents/matching/naive.py       |  4 +-
 .../sortingcomponents/matching/tdc.py         |  6 +--
 .../sortingcomponents/peak_detection.py       | 16 +++---
 .../sortingcomponents/peak_localization.py    | 30 +++++------
 .../sortingcomponents/peak_pipeline.py        |  6 +--
 .../tests/test_features_from_peaks.py         |  4 +-
 .../tests/test_motion_estimation.py           |  2 +-
 .../tests/test_peak_detection.py              |  8 +-- 
.../tests/test_waveforms/test_temporal_pca.py | 10 ++-- src/spikeinterface/sortingcomponents/tools.py | 2 +- .../waveforms/temporal_pca.py | 6 +-- 26 files changed, 108 insertions(+), 108 deletions(-) diff --git a/doc/how_to/analyse_neuropixels.rst b/doc/how_to/analyse_neuropixels.rst index 0a02a47211..1ed2c004cd 100644 --- a/doc/how_to/analyse_neuropixels.rst +++ b/doc/how_to/analyse_neuropixels.rst @@ -426,7 +426,7 @@ Let’s use here the ``locally_exclusive`` method for detection and the job_kwargs = dict(n_jobs=40, chunk_duration='1s', progress_bar=True) peaks = detect_peaks(rec, method='locally_exclusive', noise_levels=noise_levels_int16, - detect_threshold=5, local_radius_um=50., **job_kwargs) + detect_threshold=5, radius_um=50., **job_kwargs) peaks @@ -451,7 +451,7 @@ Let’s use here the ``locally_exclusive`` method for detection and the from spikeinterface.sortingcomponents.peak_localization import localize_peaks - peak_locations = localize_peaks(rec, peaks, method='center_of_mass', local_radius_um=50., **job_kwargs) + peak_locations = localize_peaks(rec, peaks, method='center_of_mass', radius_um=50., **job_kwargs) diff --git a/doc/how_to/get_started.rst b/doc/how_to/get_started.rst index 02ccb872d1..279de5c555 100644 --- a/doc/how_to/get_started.rst +++ b/doc/how_to/get_started.rst @@ -266,7 +266,7 @@ available parameters are dictionaries and can be accessed with: 'clustering': {}, 'detection': {'detect_threshold': 5, 'peak_sign': 'neg'}, 'filtering': {'dtype': 'float32'}, - 'general': {'local_radius_um': 100, 'ms_after': 2, 'ms_before': 2}, + 'general': {'radius_um': 100, 'ms_after': 2, 'ms_before': 2}, 'job_kwargs': {}, 'localization': {}, 'matching': {}, diff --git a/doc/how_to/handle_drift.rst b/doc/how_to/handle_drift.rst index c0a27ff0a3..7ff98a666b 100644 --- a/doc/how_to/handle_drift.rst +++ b/doc/how_to/handle_drift.rst @@ -118,10 +118,10 @@ to load them later. 'peak_sign': 'neg', 'detect_threshold': 8.0, 'exclude_sweep_ms': 0.1, - 'local_radius_um': 50}, + 'radius_um': 50}, 'select_kwargs': None, 'localize_peaks_kwargs': {'method': 'grid_convolution', - 'local_radius_um': 30.0, + 'radius_um': 30.0, 'upsampling_um': 3.0, 'sigma_um': array([ 5. , 12.5, 20. ]), 'sigma_ms': 0.25, diff --git a/doc/modules/motion_correction.rst b/doc/modules/motion_correction.rst index 6dc949625d..62c0d6b8d4 100644 --- a/doc/modules/motion_correction.rst +++ b/doc/modules/motion_correction.rst @@ -159,7 +159,7 @@ The high-level :py:func:`~spikeinterface.preprocessing.correct_motion()` is inte peaks = detect_peaks(rec, method="locally_exclusive", detect_threshold=8.0, **job_kwargs) # (optional) sub-select some peaks to speed up the localization peaks = select_peaks(peaks, ...) 
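    # [editor's note, not part of the original docs] The hunk below shows
    # patch 138's one recurring change: the keyword `local_radius_um` becomes
    # plain `radius_um` across detection, localization, clustering and feature
    # nodes, so this call now passes `radius_um=75.0` instead of
    # `local_radius_um=75.0`.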
- peak_locations = localize_peaks(rec, peaks, method="monopolar_triangulation",local_radius_um=75.0, + peak_locations = localize_peaks(rec, peaks, method="monopolar_triangulation",radius_um=75.0, max_distance_um=150.0, **job_kwargs) # Step 2: motion inference diff --git a/doc/modules/sortingcomponents.rst b/doc/modules/sortingcomponents.rst index b4380fc587..aa62ea5b33 100644 --- a/doc/modules/sortingcomponents.rst +++ b/doc/modules/sortingcomponents.rst @@ -51,7 +51,7 @@ follows: peak_sign='neg', detect_threshold=5, exclude_sweep_ms=0.2, - local_radius_um=100, + radius_um=100, noise_levels=None, random_chunk_kwargs={}, outputs='numpy_compact', @@ -95,7 +95,7 @@ follows: job_kwargs = dict(chunk_duration='1s', n_jobs=8, progress_bar=True) peak_locations = localize_peaks(recording, peaks, method='center_of_mass', - local_radius_um=70., ms_before=0.3, ms_after=0.6, + radius_um=70., ms_before=0.3, ms_after=0.6, **job_kwargs) diff --git a/examples/how_to/analyse_neuropixels.py b/examples/how_to/analyse_neuropixels.py index 9b9048cd0d..f3a04681e7 100644 --- a/examples/how_to/analyse_neuropixels.py +++ b/examples/how_to/analyse_neuropixels.py @@ -170,13 +170,13 @@ job_kwargs = dict(n_jobs=40, chunk_duration='1s', progress_bar=True) peaks = detect_peaks(rec, method='locally_exclusive', noise_levels=noise_levels_int16, - detect_threshold=5, local_radius_um=50., **job_kwargs) + detect_threshold=5, radius_um=50., **job_kwargs) peaks # + from spikeinterface.sortingcomponents.peak_localization import localize_peaks -peak_locations = localize_peaks(rec, peaks, method='center_of_mass', local_radius_um=50., **job_kwargs) +peak_locations = localize_peaks(rec, peaks, method='center_of_mass', radius_um=50., **job_kwargs) # - # ### Check for drift diff --git a/examples/modules_gallery/widgets/plot_4_peaks_gallery.py b/examples/modules_gallery/widgets/plot_4_peaks_gallery.py index df7d9dbf2c..addd87c065 100644 --- a/examples/modules_gallery/widgets/plot_4_peaks_gallery.py +++ b/examples/modules_gallery/widgets/plot_4_peaks_gallery.py @@ -30,7 +30,7 @@ peaks = detect_peaks( rec_filtred, method='locally_exclusive', peak_sign='neg', detect_threshold=6, exclude_sweep_ms=0.3, - local_radius_um=100, + radius_um=100, noise_levels=None, random_chunk_kwargs={}, chunk_memory='10M', n_jobs=1, progress_bar=True) diff --git a/src/spikeinterface/postprocessing/unit_localization.py b/src/spikeinterface/postprocessing/unit_localization.py index 9f303de6e1..740fdd234b 100644 --- a/src/spikeinterface/postprocessing/unit_localization.py +++ b/src/spikeinterface/postprocessing/unit_localization.py @@ -568,7 +568,7 @@ def enforce_decrease_shells_data(wf_data, maxchan, radial_parents, in_place=Fals def get_grid_convolution_templates_and_weights( - contact_locations, local_radius_um=50, upsampling_um=5, sigma_um=np.linspace(10, 50.0, 5), margin_um=50 + contact_locations, radius_um=50, upsampling_um=5, sigma_um=np.linspace(10, 50.0, 5), margin_um=50 ): x_min, x_max = contact_locations[:, 0].min(), contact_locations[:, 0].max() y_min, y_max = contact_locations[:, 1].min(), contact_locations[:, 1].max() @@ -597,7 +597,7 @@ def get_grid_convolution_templates_and_weights( # mask to get nearest template given a channel dist = sklearn.metrics.pairwise_distances(contact_locations, template_positions) - nearest_template_mask = dist < local_radius_um + nearest_template_mask = dist < radius_um weights = np.zeros((len(sigma_um), len(contact_locations), nb_templates), dtype=np.float32) for count, sigma in enumerate(sigma_um): diff --git 
a/src/spikeinterface/preprocessing/motion.py b/src/spikeinterface/preprocessing/motion.py
index 56c7e4fa05..8b0c8006d2 100644
--- a/src/spikeinterface/preprocessing/motion.py
+++ b/src/spikeinterface/preprocessing/motion.py
@@ -18,12 +18,12 @@
         peak_sign="neg",
         detect_threshold=8.0,
         exclude_sweep_ms=0.1,
-        local_radius_um=50,
+        radius_um=50,
     ),
     "select_kwargs": None,
     "localize_peaks_kwargs": dict(
         method="monopolar_triangulation",
-        local_radius_um=75.0,
+        radius_um=75.0,
         max_distance_um=150.0,
         optimizer="minimize_with_log_penality",
         enforce_decrease=True,
@@ -81,12 +81,12 @@
         peak_sign="neg",
         detect_threshold=8.0,
         exclude_sweep_ms=0.1,
-        local_radius_um=50,
+        radius_um=50,
     ),
     "select_kwargs": None,
     "localize_peaks_kwargs": dict(
         method="center_of_mass",
-        local_radius_um=75.0,
+        radius_um=75.0,
         feature="ptp",
     ),
     "estimate_motion_kwargs": dict(
@@ -109,12 +109,12 @@
         peak_sign="neg",
         detect_threshold=8.0,
         exclude_sweep_ms=0.1,
-        local_radius_um=50,
+        radius_um=50,
     ),
     "select_kwargs": None,
     "localize_peaks_kwargs": dict(
         method="grid_convolution",
-        local_radius_um=40.0,
+        radius_um=40.0,
         upsampling_um=5.0,
         sigma_um=np.linspace(5.0, 25.0, 5),
         sigma_ms=0.25,
diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py
index 24c4a7ccfc..9de2762562 100644
--- a/src/spikeinterface/sorters/internal/spyking_circus2.py
+++ b/src/spikeinterface/sorters/internal/spyking_circus2.py
@@ -21,7 +21,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
     sorter_name = "spykingcircus2"

     _default_params = {
-        "general": {"ms_before": 2, "ms_after": 2, "local_radius_um": 100},
+        "general": {"ms_before": 2, "ms_after": 2, "radius_um": 100},
         "waveforms": {"max_spikes_per_unit": 200, "overwrite": True},
         "filtering": {"dtype": "float32"},
         "detection": {"peak_sign": "neg", "detect_threshold": 5},
@@ -75,8 +75,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
         ## Then, we are detecting peaks with a locally_exclusive method
         detection_params = params["detection"].copy()
         detection_params.update(job_kwargs)
-        if "local_radius_um" not in detection_params:
-            detection_params["local_radius_um"] = params["general"]["local_radius_um"]
+        if "radius_um" not in detection_params:
+            detection_params["radius_um"] = params["general"]["radius_um"]
         if "exclude_sweep_ms" not in detection_params:
             detection_params["exclude_sweep_ms"] = max(params["general"]["ms_before"], params["general"]["ms_after"])

diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py
index a812d4ce49..42f51d3a77 100644
--- a/src/spikeinterface/sorters/internal/tridesclous2.py
+++ b/src/spikeinterface/sorters/internal/tridesclous2.py
@@ -12,7 +12,7 @@ class Tridesclous2Sorter(ComponentsBasedSorter):
     _default_params = {
         "apply_preprocessing": True,
-        "general": {"ms_before": 2.5, "ms_after": 3.5, "local_radius_um": 100},
+        "general": {"ms_before": 2.5, "ms_after": 3.5, "radius_um": 100},
         "filtering": {"freq_min": 300, "freq_max": 8000.0},
         "detection": {"peak_sign": "neg", "detect_threshold": 5, "exclude_sweep_ms": 0.4},
         "hdbscan_kwargs": {
@@ -68,7 +68,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
         # detection
         detection_params = params["detection"].copy()
-        detection_params["local_radius_um"] = params["general"]["local_radius_um"]
+        detection_params["radius_um"] = params["general"]["radius_um"]
         detection_params["noise_levels"] = noise_levels
         peaks = detect_peaks(recording, method="locally_exclusive", **detection_params, **job_kwargs)
@@ -89,7 +89,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
         # localization
         localization_params = params["localization"].copy()
-        localization_params["local_radius_um"] = params["general"]["local_radius_um"]
+        localization_params["radius_um"] = params["general"]["radius_um"]
         peak_locations = localize_peaks(
             recording, some_peaks, method="monopolar_triangulation", **localization_params, **job_kwargs
         )
@@ -127,7 +127,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
         matching_params["noise_levels"] = noise_levels
         matching_params["peak_sign"] = params["detection"]["peak_sign"]
         matching_params["detect_threshold"] = params["detection"]["detect_threshold"]
-        matching_params["local_radius_um"] = params["general"]["local_radius_um"]
+        matching_params["radius_um"] = params["general"]["radius_um"]

         # TODO: route that params
         # ~ 'num_closest' : 5,
diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py
index a6185f5193..46aba7e96f 100644
--- a/src/spikeinterface/sortingcomponents/clustering/circus.py
+++ b/src/spikeinterface/sortingcomponents/clustering/circus.py
@@ -37,7 +37,7 @@ class CircusClustering:
         },
         "cleaning_kwargs": {},
         "tmp_folder": None,
-        "local_radius_um": 100,
+        "radius_um": 100,
         "n_pca": 10,
         "max_spikes_per_unit": 200,
         "ms_before": 1.5,
@@ -104,7 +104,7 @@ def main_function(cls, recording, peaks, params):
             chan_distances = get_channel_distances(recording)

             for main_chan in unit_inds:
-                (closest_chans,) = np.nonzero(chan_distances[main_chan, :] <= params["local_radius_um"])
+                (closest_chans,) = np.nonzero(chan_distances[main_chan, :] <= params["radius_um"])
                 sparsity_mask[main_chan, closest_chans] = True

             if params["waveform_mode"] == "shared_memory":
diff --git a/src/spikeinterface/sortingcomponents/clustering/position_and_features.py b/src/spikeinterface/sortingcomponents/clustering/position_and_features.py
index 082d2dc0ba..8d21041599 100644
--- a/src/spikeinterface/sortingcomponents/clustering/position_and_features.py
+++ b/src/spikeinterface/sortingcomponents/clustering/position_and_features.py
@@ -35,7 +35,7 @@ class PositionAndFeaturesClustering:
         "cluster_selection_method": "leaf",
         },
         "cleaning_kwargs": {},
-        "local_radius_um": 100,
+        "radius_um": 100,
         "max_spikes_per_unit": 200,
         "selection_method": "random",
         "ms_before": 1.5,
@@ -69,9 +69,9 @@ def main_function(cls, recording, peaks, params):
         features_list = [position_method, "ptp", "energy"]

         features_params = {
-            position_method: {"local_radius_um": params["local_radius_um"]},
-            "ptp": {"all_channels": False, "local_radius_um": params["local_radius_um"]},
-            "energy": {"local_radius_um": params["local_radius_um"]},
+            position_method: {"radius_um": params["radius_um"]},
+            "ptp": {"all_channels": False, "radius_um": params["radius_um"]},
+            "energy": {"radius_um": params["radius_um"]},
         }

         features_data = compute_features_from_peaks(
diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py
index 02247dd288..fcbcac097f 100644
--- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py
+++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py
@@ -34,7 +34,7 @@ class RandomProjectionClustering:
         "cluster_selection_method": "leaf",
         },
         "cleaning_kwargs": {},
-        "local_radius_um": 100,
+        "radius_um": 100,
         "max_spikes_per_unit": 200,
         "selection_method": "closest_to_centroid",
         "nb_projections": {"ptp": 8, "energy": 2},
@@ -106,7 +106,7 @@ def main_function(cls, recording, peaks, params):
             projections = np.random.randn(num_chans, d["nb_projections"][proj_type])

             features_params[f"random_projections_{proj_type}"] = {
-                "local_radius_um": params["local_radius_um"],
+                "radius_um": params["radius_um"],
                 "projections": projections,
                 "min_values": min_values,
             }
diff --git a/src/spikeinterface/sortingcomponents/features_from_peaks.py b/src/spikeinterface/sortingcomponents/features_from_peaks.py
index c075e8e7c1..adc025e829 100644
--- a/src/spikeinterface/sortingcomponents/features_from_peaks.py
+++ b/src/spikeinterface/sortingcomponents/features_from_peaks.py
@@ -105,15 +105,15 @@ def compute(self, traces, peaks, waveforms):

 class PeakToPeakFeature(PipelineNode):
     def __init__(
-        self, recording, name="ptp_feature", return_output=True, parents=None, local_radius_um=150.0, all_channels=True
+        self, recording, name="ptp_feature", return_output=True, parents=None, radius_um=150.0, all_channels=True
     ):
         PipelineNode.__init__(self, recording, return_output=return_output, parents=parents)

         self.contact_locations = recording.get_channel_locations()
         self.channel_distance = get_channel_distances(recording)
-        self.neighbours_mask = self.channel_distance < local_radius_um
+        self.neighbours_mask = self.channel_distance < radius_um
         self.all_channels = all_channels
-        self._kwargs.update(dict(local_radius_um=local_radius_um, all_channels=all_channels))
+        self._kwargs.update(dict(radius_um=radius_um, all_channels=all_channels))
         self._dtype = recording.get_dtype()

     def get_dtype(self):
@@ -139,19 +139,19 @@ def __init__(
         name="ptp_lag_feature",
         return_output=True,
         parents=None,
-        local_radius_um=150.0,
+        radius_um=150.0,
         all_channels=True,
     ):
         PipelineNode.__init__(self, recording, return_output=return_output, parents=parents)

         self.all_channels = all_channels
-        self.local_radius_um = local_radius_um
+        self.radius_um = radius_um

         self.contact_locations = recording.get_channel_locations()
         self.channel_distance = get_channel_distances(recording)
-        self.neighbours_mask = self.channel_distance < local_radius_um
+        self.neighbours_mask = self.channel_distance < radius_um

-        self._kwargs.update(dict(local_radius_um=local_radius_um, all_channels=all_channels))
+        self._kwargs.update(dict(radius_um=radius_um, all_channels=all_channels))
         self._dtype = recording.get_dtype()

     def get_dtype(self):
@@ -184,20 +184,20 @@ def __init__(
         return_output=True,
         parents=None,
         projections=None,
-        local_radius_um=150.0,
+        radius_um=150.0,
         min_values=None,
     ):
         PipelineNode.__init__(self, recording, return_output=return_output, parents=parents)

         self.projections = projections
-        self.local_radius_um = local_radius_um
+        self.radius_um = radius_um
         self.min_values = min_values

         self.contact_locations = recording.get_channel_locations()
         self.channel_distance = get_channel_distances(recording)
-        self.neighbours_mask = self.channel_distance < local_radius_um
+        self.neighbours_mask = self.channel_distance < radius_um

-        self._kwargs.update(dict(projections=projections, local_radius_um=local_radius_um, min_values=min_values))
+        self._kwargs.update(dict(projections=projections, radius_um=radius_um, min_values=min_values))

         self._dtype = recording.get_dtype()
@@ -230,19 +230,19 @@ def __init__(
         return_output=True,
         parents=None,
         projections=None,
-        local_radius_um=150.0,
+        radius_um=150.0,
         min_values=None,
     ):
         PipelineNode.__init__(self, recording, return_output=return_output, parents=parents)

         self.contact_locations = recording.get_channel_locations()
         self.channel_distance = get_channel_distances(recording)
-        self.neighbours_mask = self.channel_distance < local_radius_um
+        self.neighbours_mask = self.channel_distance < radius_um

         self.projections = projections
         self.min_values = min_values
-        self.local_radius_um = local_radius_um
-        self._kwargs.update(dict(projections=projections, min_values=min_values, local_radius_um=local_radius_um))
+        self.radius_um = radius_um
+        self._kwargs.update(dict(projections=projections, min_values=min_values, radius_um=radius_um))
         self._dtype = recording.get_dtype()

     def get_dtype(self):
@@ -267,14 +267,14 @@ def compute(self, traces, peaks, waveforms):

 class StdPeakToPeakFeature(PipelineNode):
-    def __init__(self, recording, name="std_ptp_feature", return_output=True, parents=None, local_radius_um=150.0):
+    def __init__(self, recording, name="std_ptp_feature", return_output=True, parents=None, radius_um=150.0):
         PipelineNode.__init__(self, recording, return_output=return_output, parents=parents)

         self.contact_locations = recording.get_channel_locations()
         self.channel_distance = get_channel_distances(recording)
-        self.neighbours_mask = self.channel_distance < local_radius_um
+        self.neighbours_mask = self.channel_distance < radius_um

-        self._kwargs.update(dict(local_radius_um=local_radius_um))
+        self._kwargs.update(dict(radius_um=radius_um))

         self._dtype = recording.get_dtype()
@@ -292,14 +292,14 @@ def compute(self, traces, peaks, waveforms):

 class GlobalPeakToPeakFeature(PipelineNode):
-    def __init__(self, recording, name="global_ptp_feature", return_output=True, parents=None, local_radius_um=150.0):
+    def __init__(self, recording, name="global_ptp_feature", return_output=True, parents=None, radius_um=150.0):
         PipelineNode.__init__(self, recording, return_output=return_output, parents=parents)

         self.contact_locations = recording.get_channel_locations()
         self.channel_distance = get_channel_distances(recording)
-        self.neighbours_mask = self.channel_distance < local_radius_um
+        self.neighbours_mask = self.channel_distance < radius_um

-        self._kwargs.update(dict(local_radius_um=local_radius_um))
+        self._kwargs.update(dict(radius_um=radius_um))

         self._dtype = recording.get_dtype()
@@ -317,14 +317,14 @@ def compute(self, traces, peaks, waveforms):

 class KurtosisPeakToPeakFeature(PipelineNode):
-    def __init__(self, recording, name="kurtosis_ptp_feature", return_output=True, parents=None, local_radius_um=150.0):
+    def __init__(self, recording, name="kurtosis_ptp_feature", return_output=True, parents=None, radius_um=150.0):
         PipelineNode.__init__(self, recording, return_output=return_output, parents=parents)

         self.contact_locations = recording.get_channel_locations()
         self.channel_distance = get_channel_distances(recording)
-        self.neighbours_mask = self.channel_distance < local_radius_um
+        self.neighbours_mask = self.channel_distance < radius_um

-        self._kwargs.update(dict(local_radius_um=local_radius_um))
+        self._kwargs.update(dict(radius_um=radius_um))

         self._dtype = recording.get_dtype()
@@ -344,14 +344,14 @@ def compute(self, traces, peaks, waveforms):

 class EnergyFeature(PipelineNode):
-    def __init__(self, recording, name="energy_feature", return_output=True, parents=None, local_radius_um=50.0):
+    def __init__(self, recording, name="energy_feature", return_output=True, parents=None, radius_um=50.0):
         PipelineNode.__init__(self, recording, return_output=return_output, parents=parents)

         self.contact_locations = recording.get_channel_locations()
         self.channel_distance = get_channel_distances(recording)
-        self.neighbours_mask = self.channel_distance < local_radius_um
+        self.neighbours_mask = self.channel_distance < radius_um

-        self._kwargs.update(dict(local_radius_um=local_radius_um))
+        self._kwargs.update(dict(radius_um=radius_um))

     def get_dtype(self):
         return np.dtype("float32")
diff --git a/src/spikeinterface/sortingcomponents/matching/naive.py b/src/spikeinterface/sortingcomponents/matching/naive.py
index ba4c0e93f3..4e2625acec 100644
--- a/src/spikeinterface/sortingcomponents/matching/naive.py
+++ b/src/spikeinterface/sortingcomponents/matching/naive.py
@@ -35,7 +35,7 @@ class NaiveMatching(BaseTemplateMatchingEngine):
         "exclude_sweep_ms": 0.1,
         "detect_threshold": 5,
         "noise_levels": None,
-        "local_radius_um": 100,
+        "radius_um": 100,
         "random_chunk_kwargs": {},
     }

@@ -54,7 +54,7 @@ def initialize_and_check_kwargs(cls, recording, kwargs):
         d["abs_threholds"] = d["noise_levels"] * d["detect_threshold"]

         channel_distance = get_channel_distances(recording)
-        d["neighbours_mask"] = channel_distance < d["local_radius_um"]
+        d["neighbours_mask"] = channel_distance < d["radius_um"]

         d["nbefore"] = we.nbefore
         d["nafter"] = we.nafter
diff --git a/src/spikeinterface/sortingcomponents/matching/tdc.py b/src/spikeinterface/sortingcomponents/matching/tdc.py
index 5fbe1b94f3..7d6d707ea2 100644
--- a/src/spikeinterface/sortingcomponents/matching/tdc.py
+++ b/src/spikeinterface/sortingcomponents/matching/tdc.py
@@ -50,7 +50,7 @@ class TridesclousPeeler(BaseTemplateMatchingEngine):
         "peak_shift_ms": 0.2,
         "detect_threshold": 5,
         "noise_levels": None,
-        "local_radius_um": 100,
+        "radius_um": 100,
         "num_closest": 5,
         "sample_shift": 3,
         "ms_before": 0.8,
@@ -103,7 +103,7 @@ def initialize_and_check_kwargs(cls, recording, kwargs):
         d["abs_threholds"] = d["noise_levels"] * d["detect_threshold"]

         channel_distance = get_channel_distances(recording)
-        d["neighbours_mask"] = channel_distance < d["local_radius_um"]
+        d["neighbours_mask"] = channel_distance < d["radius_um"]

         sparsity = compute_sparsity(we, method="snr", peak_sign=d["peak_sign"], threshold=d["detect_threshold"])
         template_sparsity_inds = sparsity.unit_id_to_channel_indices
@@ -154,7 +154,7 @@ def initialize_and_check_kwargs(cls, recording, kwargs):
         # distance channel from unit
         distances = scipy.spatial.distance.cdist(channel_locations, unit_locations, metric="euclidean")
-        near_cluster_mask = distances < d["local_radius_um"]
+        near_cluster_mask = distances < d["radius_um"]

         # nearby cluster for each channel
         possible_clusters_by_channel = []
diff --git a/src/spikeinterface/sortingcomponents/peak_detection.py b/src/spikeinterface/sortingcomponents/peak_detection.py
index df3374b39d..4fd7611bb7 100644
--- a/src/spikeinterface/sortingcomponents/peak_detection.py
+++ b/src/spikeinterface/sortingcomponents/peak_detection.py
@@ -504,7 +504,7 @@ class DetectPeakLocallyExclusive(PeakDetectorWrapper):
     params_doc = (
         DetectPeakByChannel.params_doc
         + """
-    local_radius_um: float
+    radius_um: float
         The radius to use to select neighbour channels for locally exclusive detection.
     """
     )
@@ -516,7 +516,7 @@ def check_params(
         peak_sign="neg",
         detect_threshold=5,
         exclude_sweep_ms=0.1,
-        local_radius_um=50,
+        radius_um=50,
         noise_levels=None,
         random_chunk_kwargs={},
     ):
@@ -533,7 +533,7 @@ def check_params(
         )

         channel_distance = get_channel_distances(recording)
-        neighbours_mask = channel_distance < local_radius_um
+        neighbours_mask = channel_distance < radius_um
         return args + (neighbours_mask,)

     @classmethod
@@ -580,7 +580,7 @@ class DetectPeakLocallyExclusiveTorch(PeakDetectorWrapper):
     params_doc = (
         DetectPeakByChannel.params_doc
         + """
-    local_radius_um: float
+    radius_um: float
         The radius to use to select neighbour channels for locally exclusive detection.
     """
     )
@@ -594,7 +594,7 @@ def check_params(
         exclude_sweep_ms=0.1,
         noise_levels=None,
         device=None,
-        local_radius_um=50,
+        radius_um=50,
         return_tensor=False,
         random_chunk_kwargs={},
     ):
@@ -615,7 +615,7 @@ def check_params(
         neighbour_indices_by_chan = []
         num_channels = recording.get_num_channels()
         for chan in range(num_channels):
-            neighbour_indices_by_chan.append(np.nonzero(channel_distance[chan] < local_radius_um)[0])
+            neighbour_indices_by_chan.append(np.nonzero(channel_distance[chan] < radius_um)[0])
         max_neighbs = np.max([len(neigh) for neigh in neighbour_indices_by_chan])
         neighbours_idxs = num_channels * np.ones((num_channels, max_neighbs), dtype=int)
         for i, neigh in enumerate(neighbour_indices_by_chan):
@@ -836,7 +836,7 @@ def check_params(
         peak_sign="neg",
         detect_threshold=5,
         exclude_sweep_ms=0.1,
-        local_radius_um=50,
+        radius_um=50,
         noise_levels=None,
         random_chunk_kwargs={},
     ):
@@ -847,7 +847,7 @@ def check_params(
         abs_threholds = noise_levels * detect_threshold
         exclude_sweep_size = int(exclude_sweep_ms * recording.get_sampling_frequency() / 1000.0)
         channel_distance = get_channel_distances(recording)
-        neighbours_mask = channel_distance < local_radius_um
+        neighbours_mask = channel_distance < radius_um

         executor = OpenCLDetectPeakExecutor(abs_threholds, exclude_sweep_size, neighbours_mask, peak_sign)
diff --git a/src/spikeinterface/sortingcomponents/peak_localization.py b/src/spikeinterface/sortingcomponents/peak_localization.py
index d1df720624..2e61f00ae7 100644
--- a/src/spikeinterface/sortingcomponents/peak_localization.py
+++ b/src/spikeinterface/sortingcomponents/peak_localization.py
@@ -101,14 +101,14 @@ def localize_peaks(recording, peaks, method="center_of_mass", ms_before=0.5, ms_

 class LocalizeBase(PipelineNode):
-    def __init__(self, recording, return_output=True, parents=None, local_radius_um=75.0):
+    def __init__(self, recording, return_output=True, parents=None, radius_um=75.0):
         PipelineNode.__init__(self, recording, return_output=return_output, parents=parents)

-        self.local_radius_um = local_radius_um
+        self.radius_um = radius_um
         self.contact_locations = recording.get_channel_locations()
         self.channel_distance = get_channel_distances(recording)
-        self.neighbours_mask = self.channel_distance < local_radius_um
-        self._kwargs["local_radius_um"] = local_radius_um
+        self.neighbours_mask = self.channel_distance < radius_um
+        self._kwargs["radius_um"] = radius_um

     def get_dtype(self):
         return self._dtype
@@ -152,17 +152,17 @@ class LocalizeCenterOfMass(LocalizeBase):
     need_waveforms = True
     name = "center_of_mass"
     params_doc = """
-    local_radius_um: float
+    radius_um: float
         Radius in um for channel sparsity.
     feature: str ['ptp', 'mean', 'energy', 'peak_voltage']
         Feature to consider for computation.
         Default is 'ptp'
     """

     def __init__(
-        self, recording, return_output=True, parents=["extract_waveforms"], local_radius_um=75.0, feature="ptp"
+        self, recording, return_output=True, parents=["extract_waveforms"], radius_um=75.0, feature="ptp"
     ):
         LocalizeBase.__init__(
-            self, recording, return_output=return_output, parents=parents, local_radius_um=local_radius_um
+            self, recording, return_output=return_output, parents=parents, radius_um=radius_um
         )

         self._dtype = np.dtype(dtype_localize_by_method["center_of_mass"])
@@ -216,7 +216,7 @@ class LocalizeMonopolarTriangulation(PipelineNode):
     need_waveforms = False
     name = "monopolar_triangulation"
     params_doc = """
-    local_radius_um: float
+    radius_um: float
         For channel sparsity.
     max_distance_um: float, default: 1000
         Boundary for distance estimation.
@@ -234,14 +234,14 @@ def __init__(
         recording,
         return_output=True,
         parents=["extract_waveforms"],
-        local_radius_um=75.0,
+        radius_um=75.0,
         max_distance_um=150.0,
         optimizer="minimize_with_log_penality",
         enforce_decrease=True,
         feature="ptp",
     ):
         LocalizeBase.__init__(
-            self, recording, return_output=return_output, parents=parents, local_radius_um=local_radius_um
+            self, recording, return_output=return_output, parents=parents, radius_um=radius_um
         )

         assert feature in ["ptp", "energy", "peak_voltage"], f"{feature} is not a valid feature"
@@ -309,7 +309,7 @@ class LocalizeGridConvolution(PipelineNode):
     need_waveforms = True
     name = "grid_convolution"
     params_doc = """
-    local_radius_um: float
+    radius_um: float
         Radius in um for channel sparsity.
     upsampling_um: float
         Upsampling resolution for the grid of templates
@@ -333,7 +333,7 @@ def __init__(
         recording,
         return_output=True,
         parents=["extract_waveforms"],
-        local_radius_um=40.0,
+        radius_um=40.0,
         upsampling_um=5.0,
         sigma_um=np.linspace(5.0, 25.0, 5),
         sigma_ms=0.25,
@@ -344,7 +344,7 @@ def __init__(
     ):
         PipelineNode.__init__(self, recording, return_output=return_output, parents=parents)

-        self.local_radius_um = local_radius_um
+        self.radius_um = radius_um
         self.sigma_um = sigma_um
         self.margin_um = margin_um
         self.upsampling_um = upsampling_um
@@ -371,7 +371,7 @@ def __init__(
             self.prototype = self.prototype[:, np.newaxis]

         self.template_positions, self.weights, self.nearest_template_mask = get_grid_convolution_templates_and_weights(
-            contact_locations, self.local_radius_um, self.upsampling_um, self.sigma_um, self.margin_um
+            contact_locations, self.radius_um, self.upsampling_um, self.sigma_um, self.margin_um
         )

         self.weights_sparsity_mask = self.weights > self.sparsity_threshold
@@ -379,7 +379,7 @@ def __init__(
         self._dtype = np.dtype(dtype_localize_by_method["grid_convolution"])
         self._kwargs.update(
             dict(
-                local_radius_um=self.local_radius_um,
+                radius_um=self.radius_um,
                 prototype=self.prototype,
                 template_positions=self.template_positions,
                 nearest_template_mask=self.nearest_template_mask,
diff --git a/src/spikeinterface/sortingcomponents/peak_pipeline.py b/src/spikeinterface/sortingcomponents/peak_pipeline.py
index 9e43fd2d78..6f0f26201f 100644
--- a/src/spikeinterface/sortingcomponents/peak_pipeline.py
+++ b/src/spikeinterface/sortingcomponents/peak_pipeline.py
@@ -223,7 +223,7 @@ def __init__(
         ms_after: float,
         parents: Optional[List[PipelineNode]] = None,
         return_output: bool = False,
-        local_radius_um: float = 100.0,
+        radius_um: float = 100.0,
     ):
         """
         Extract sparse waveforms from a recording.
         The strategy in this specific node is to reshape the waveforms
@@ -260,10 +260,10 @@ def __init__(
             return_output=return_output,
         )

-        self.local_radius_um = local_radius_um
+        self.radius_um = radius_um
         self.contact_locations = recording.get_channel_locations()
         self.channel_distance = get_channel_distances(recording)
-        self.neighbours_mask = self.channel_distance < local_radius_um
+        self.neighbours_mask = self.channel_distance < radius_um
         self.max_num_chans = np.max(np.sum(self.neighbours_mask, axis=1))

     def get_trace_margin(self):
diff --git a/src/spikeinterface/sortingcomponents/tests/test_features_from_peaks.py b/src/spikeinterface/sortingcomponents/tests/test_features_from_peaks.py
index e46d037c9e..b3b5f656cb 100644
--- a/src/spikeinterface/sortingcomponents/tests/test_features_from_peaks.py
+++ b/src/spikeinterface/sortingcomponents/tests/test_features_from_peaks.py
@@ -34,8 +34,8 @@ def test_features_from_peaks():
     feature_params = {
         "amplitude": {"all_channels": False, "peak_sign": "neg"},
         "ptp": {"all_channels": False},
-        "center_of_mass": {"local_radius_um": 120.0},
-        "energy": {"local_radius_um": 160.0},
+        "center_of_mass": {"radius_um": 120.0},
+        "energy": {"radius_um": 160.0},
     }

     features = compute_features_from_peaks(recording, peaks, feature_list, feature_params=feature_params, **job_kwargs)
diff --git a/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py b/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py
index 9860275739..0558c16cca 100644
--- a/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py
+++ b/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py
@@ -45,7 +45,7 @@ def setup_module():
     extract_dense_waveforms = ExtractDenseWaveforms(recording, ms_before=0.1, ms_after=0.3, return_output=False)
     pipeline_nodes = [
         extract_dense_waveforms,
-        LocalizeCenterOfMass(recording, parents=[extract_dense_waveforms], local_radius_um=60.0),
+        LocalizeCenterOfMass(recording, parents=[extract_dense_waveforms], radius_um=60.0),
     ]
     peaks, peak_locations = detect_peaks(
         recording,
diff --git a/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py b/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py
index 380bd67a94..f3ca8bf96d 100644
--- a/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py
+++ b/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py
@@ -139,7 +139,7 @@ def peak_detector_kwargs(recording):
         exclude_sweep_ms=1.0,
         peak_sign="both",
         detect_threshold=5,
-        local_radius_um=50,
+        radius_um=50,
     )
     return peak_detector_keyword_arguments

@@ -194,12 +194,12 @@ def test_iterative_peak_detection_sparse(recording, job_kwargs, pca_model_folder
     ms_before = 1.0
     ms_after = 1.0
-    local_radius_um = 40
+    radius_um = 40

     waveform_extraction_node = ExtractSparseWaveforms(
         recording=recording,
         ms_before=ms_before,
         ms_after=ms_after,
-        local_radius_um=local_radius_um,
+        radius_um=radius_um,
     )

     waveform_denoising_node = TemporalPCADenoising(
@@ -368,7 +368,7 @@ def test_peak_detection_with_pipeline(recording, job_kwargs, torch_job_kwargs):
     pipeline_nodes = [
         extract_dense_waveforms,
         PeakToPeakFeature(recording, all_channels=False, parents=[extract_dense_waveforms]),
-        LocalizeCenterOfMass(recording, local_radius_um=50.0, parents=[extract_dense_waveforms]),
+        LocalizeCenterOfMass(recording, radius_um=50.0, parents=[extract_dense_waveforms]),
     ]
     peaks, ptp, peak_locations = detect_peaks(
         recording,
diff --git a/src/spikeinterface/sortingcomponents/tests/test_waveforms/test_temporal_pca.py b/src/spikeinterface/sortingcomponents/tests/test_waveforms/test_temporal_pca.py
index c4192c5fcf..34bc93fbfa 100644
--- a/src/spikeinterface/sortingcomponents/tests/test_waveforms/test_temporal_pca.py
+++ b/src/spikeinterface/sortingcomponents/tests/test_waveforms/test_temporal_pca.py
@@ -83,7 +83,7 @@ def test_pca_denoising_sparse(mearec_recording, detected_peaks, model_path_of_tr
     peaks = detected_peaks

     # Parameters
-    local_radius_um = 40
+    radius_um = 40
     ms_before = 1.0
     ms_after = 1.0

@@ -94,7 +94,7 @@ def test_pca_denoising_sparse(mearec_recording, detected_peaks, model_path_of_tr
         parents=[peak_retriever],
         ms_before=ms_before,
         ms_after=ms_after,
-        local_radius_um=local_radius_um,
+        radius_um=radius_um,
         return_output=True,
     )
     pca_denoising = TemporalPCADenoising(
@@ -143,7 +143,7 @@ def test_pca_projection_sparsity(mearec_recording, detected_peaks, model_path_of
     peaks = detected_peaks

     # Parameters
-    local_radius_um = 40
+    radius_um = 40
     ms_before = 1.0
     ms_after = 1.0

@@ -154,7 +154,7 @@ def test_pca_projection_sparsity(mearec_recording, detected_peaks, model_path_of
         parents=[peak_retriever],
         ms_before=ms_before,
         ms_after=ms_after,
-        local_radius_um=local_radius_um,
+        radius_um=radius_um,
         return_output=True,
     )
     temporal_pca = TemporalPCAProjection(
@@ -181,7 +181,7 @@ def test_initialization_with_wrong_parents_failure(mearec_recording, model_path_
     model_folder_path = model_path_of_trained_pca
     dummy_parent = PipelineNode(recording=recording)
     extract_waveforms = ExtractSparseWaveforms(
-        recording=recording, ms_before=1, ms_after=1, local_radius_um=40, return_output=True
+        recording=recording, ms_before=1, ms_after=1, radius_um=40, return_output=True
     )

     match_error = f"TemporalPCA should have a single {WaveformsNode.__name__} in its parents"
diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py
index 5283fd0f99..14b66fc847 100644
--- a/src/spikeinterface/sortingcomponents/tools.py
+++ b/src/spikeinterface/sortingcomponents/tools.py
@@ -29,7 +29,7 @@ def get_prototype_spike(recording, peaks, job_kwargs, nb_peaks=1000, ms_before=0
         ms_before=ms_before,
         ms_after=ms_after,
         return_output=True,
-        local_radius_um=5,
+        radius_um=5,
     )

     nbefore = sparse_waveforms.nbefore
diff --git a/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py b/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py
index de96fe445a..28cf8a3be0 100644
--- a/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py
+++ b/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py
@@ -93,7 +93,7 @@ def fit(
         ms_before: float = 1.0,
         ms_after: float = 1.0,
         whiten: bool = True,
-        local_radius_um: float = None,
+        radius_um: float = None,
     ) -> IncrementalPCA:
         """
         Train a pca model using the data in the recording object and the parameters provided.
@@ -114,7 +114,7 @@ def fit(
             The parameters for peak selection.
         whiten : bool, optional
             Whether to whiten the data, by default True.
-        local_radius_um : float, optional
+        radius_um : float, optional
             The radius (in micrometers) to use for defining sparsity, by default None.
         ms_before : float, optional
             The number of milliseconds to include before the peak of the spike, by default 1.
@@ -148,7 +148,7 @@ def fit(
         )

         # compute PCA by_channel_global (with sparsity)
-        sparsity = ChannelSparsity.from_radius(we, radius_um=local_radius_um) if local_radius_um else None
+        sparsity = ChannelSparsity.from_radius(we, radius_um=radius_um) if radius_um else None
         pc = compute_principal_components(
             we, n_components=n_components, mode="by_channel_global", sparsity=sparsity, whiten=whiten
         )

From 20844cc8fd95f81efc8112a93cd40c9eef1f19b3 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Thu, 20 Jul 2023 13:24:01 +0000
Subject: [PATCH 139/166] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 .../sortingcomponents/peak_localization.py | 12 +++--------
 1 file changed, 3 insertions(+), 9 deletions(-)

diff --git a/src/spikeinterface/sortingcomponents/peak_localization.py b/src/spikeinterface/sortingcomponents/peak_localization.py
index 2e61f00ae7..bd793b3f53 100644
--- a/src/spikeinterface/sortingcomponents/peak_localization.py
+++ b/src/spikeinterface/sortingcomponents/peak_localization.py
@@ -158,12 +158,8 @@ class LocalizeCenterOfMass(LocalizeBase):
         Feature to consider for computation.
         Default is 'ptp'
     """

-    def __init__(
-        self, recording, return_output=True, parents=["extract_waveforms"], radius_um=75.0, feature="ptp"
-    ):
-        LocalizeBase.__init__(
-            self, recording, return_output=return_output, parents=parents, radius_um=radius_um
-        )
+    def __init__(self, recording, return_output=True, parents=["extract_waveforms"], radius_um=75.0, feature="ptp"):
+        LocalizeBase.__init__(self, recording, return_output=return_output, parents=parents, radius_um=radius_um)

         self._dtype = np.dtype(dtype_localize_by_method["center_of_mass"])
         assert feature in ["ptp", "mean", "energy", "peak_voltage"], f"{feature} is not a valid feature"
@@ -240,9 +236,7 @@ def __init__(
         enforce_decrease=True,
         feature="ptp",
     ):
-        LocalizeBase.__init__(
-            self, recording, return_output=return_output, parents=parents, radius_um=radius_um
-        )
+        LocalizeBase.__init__(self, recording, return_output=return_output, parents=parents, radius_um=radius_um)

         assert feature in ["ptp", "energy", "peak_voltage"], f"{feature} is not a valid feature"
         self.max_distance_um = max_distance_um

From b13aff171fdf27f155342483edeeb8972d7cd8f5 Mon Sep 17 00:00:00 2001
From: Garcia Samuel
Date: Thu, 20 Jul 2023 16:07:42 +0200
Subject: [PATCH 140/166] Update pyproject.toml

Co-authored-by: Alessio Buccino
---
 pyproject.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index 59afcff264..3ecfbe2718 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "spikeinterface"
-version = "0.99.0"
+version = "0.99.0.dev0"
 authors = [
   { name="Alessio Buccino", email="alessiop.buccino@gmail.com" },
   { name="Samuel Garcia", email="sam.garcia.die@gmail.com" },

From aa3b7c47318552060aae9b13ae9c1e5b3db0a080 Mon Sep 17 00:00:00 2001
From: Samuel Garcia
Date: Fri, 21 Jul 2023 09:47:28 +0200
Subject: [PATCH 141/166] fix plot_trace legend

---
 src/spikeinterface/widgets/traces.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/spikeinterface/widgets/traces.py b/src/spikeinterface/widgets/traces.py
index c9dc04811a..405c4b6b79 100644
--- a/src/spikeinterface/widgets/traces.py
+++ b/src/spikeinterface/widgets/traces.py
@@ -290,7 +290,8 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):
         check_ipywidget_backend()

         self.next_data_plot = data_plot.copy()
-
+        self.next_data_plot["add_legend"] = False
+
         recordings = data_plot["recordings"]

         # first layer

From 1855b8dca1b929428be0560875d21af30f6fcf48 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 21 Jul 2023 07:48:13 +0000
Subject: [PATCH 142/166] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/spikeinterface/widgets/traces.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/spikeinterface/widgets/traces.py b/src/spikeinterface/widgets/traces.py
index 405c4b6b79..9a2ec4a215 100644
--- a/src/spikeinterface/widgets/traces.py
+++ b/src/spikeinterface/widgets/traces.py
@@ -291,7 +291,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs):
         self.next_data_plot = data_plot.copy()
         self.next_data_plot["add_legend"] = False
-
+
         recordings = data_plot["recordings"]

         # first layer

From 1982bc4dab906c2b01741700c8a6b21d008f8d3a Mon Sep 17 00:00:00 2001
From: Aurélien WYNGAARD
Date: Fri, 21 Jul 2023 12:11:13 +0200
Subject: [PATCH 143/166] Added `ZarrRecordingExtractor` to recording list

---
 src/spikeinterface/extractors/extractorlist.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/spikeinterface/extractors/extractorlist.py b/src/spikeinterface/extractors/extractorlist.py
index 407d388044..6630c7b2c9 100644
--- a/src/spikeinterface/extractors/extractorlist.py
+++ b/src/spikeinterface/extractors/extractorlist.py
@@ -9,6 +9,7 @@
     NpzSortingExtractor,
     NumpySorting,
     NpySnippetsExtractor,
+    ZarrRecordingExtractor,
 )

 # sorting/recording/event from neo
@@ -58,6 +59,7 @@
 recording_extractor_full_list = [
     BinaryRecordingExtractor,
+    ZarrRecordingExtractor,
     # natively implemented in spikeinterface.extractors
     NumpyRecording,
     SHYBRIDRecordingExtractor,

From 0fbce884eb7ef7038a53d48229a6054f8236d1fb Mon Sep 17 00:00:00 2001
From: Aurélien WYNGAARD
Date: Fri, 21 Jul 2023 12:36:31 +0200
Subject: [PATCH 144/166] Also added `BinaryFolderRecording`

---
 src/spikeinterface/extractors/extractorlist.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/spikeinterface/extractors/extractorlist.py b/src/spikeinterface/extractors/extractorlist.py
index 6630c7b2c9..ebff40fae0 100644
--- a/src/spikeinterface/extractors/extractorlist.py
+++ b/src/spikeinterface/extractors/extractorlist.py
@@ -4,6 +4,7 @@
 from spikeinterface.core import (
     BaseRecording,
     BaseSorting,
+    BinaryFolderRecording,
     BinaryRecordingExtractor,
     NumpyRecording,
     NpzSortingExtractor,
@@ -58,6 +59,7 @@
 ########################################

 recording_extractor_full_list = [
+    BinaryFolderRecording,
     BinaryRecordingExtractor,
     ZarrRecordingExtractor,
     # natively implemented in spikeinterface.extractors

From 734670510ef3d58f2943ec1e36b79b1e4f6b2a98 Mon Sep 17 00:00:00 2001
From: Samuel Garcia
Date: Fri, 21 Jul 2023 14:13:04 +0200
Subject: [PATCH 145/166] feedback from Alessio

---
 src/spikeinterface/widgets/all_amplitudes_distributions.py | 3 ---
 src/spikeinterface/widgets/quality_metrics.py              | 2 --
 src/spikeinterface/widgets/sorting_summary.py              | 2 --
 src/spikeinterface/widgets/template_similarity.py          | 2 --
 src/spikeinterface/widgets/tests/test_widgets.py           | 2 --
 src/spikeinterface/widgets/unit_depths.py                  | 2 --
 src/spikeinterface/widgets/unit_locations.py               | 2 --
 7 files changed, 15 deletions(-)

diff --git a/src/spikeinterface/widgets/all_amplitudes_distributions.py b/src/spikeinterface/widgets/all_amplitudes_distributions.py
index 280662fd7a..e8b25f6823 100644
--- a/src/spikeinterface/widgets/all_amplitudes_distributions.py
+++ b/src/spikeinterface/widgets/all_amplitudes_distributions.py
@@ -50,9 +50,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
         import matplotlib.pyplot as plt
         from .utils_matplotlib import make_mpl_figure

-        from matplotlib.patches import Ellipse
-        from matplotlib.lines import Line2D
-
         dp = to_attr(data_plot)

         self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs)
diff --git a/src/spikeinterface/widgets/quality_metrics.py b/src/spikeinterface/widgets/quality_metrics.py
index 459a32e6f2..4a6b46b72d 100644
--- a/src/spikeinterface/widgets/quality_metrics.py
+++ b/src/spikeinterface/widgets/quality_metrics.py
@@ -22,8 +22,6 @@ class QualityMetricsWidget(MetricsBaseWidget):
         For sortingview backend, if True the unit selector is not displayed, default False
     """

-    # possible_backends = {}
-
     def __init__(
         self,
         waveform_extractor: WaveformExtractor,
diff --git a/src/spikeinterface/widgets/sorting_summary.py b/src/spikeinterface/widgets/sorting_summary.py
index 9291de2956..b9760205f9 100644
--- a/src/spikeinterface/widgets/sorting_summary.py
+++ b/src/spikeinterface/widgets/sorting_summary.py
@@ -34,8 +34,6 @@ class SortingSummaryWidget(BaseWidget):
         (sortingview backend)
     """

-    # possible_backends = {}
-
     def __init__(
         self,
         waveform_extractor: WaveformExtractor,
diff --git a/src/spikeinterface/widgets/template_similarity.py b/src/spikeinterface/widgets/template_similarity.py
index 69aad70b1f..63ac177835 100644
--- a/src/spikeinterface/widgets/template_similarity.py
+++ b/src/spikeinterface/widgets/template_similarity.py
@@ -25,8 +25,6 @@ class TemplateSimilarityWidget(BaseWidget):
         If True, color bar is displayed, default True.
     """

-    # possible_backends = {}
-
     def __init__(
         self,
         waveform_extractor: WaveformExtractor,
diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py
index 96c6ab80eb..7bf508fe71 100644
--- a/src/spikeinterface/widgets/tests/test_widgets.py
+++ b/src/spikeinterface/widgets/tests/test_widgets.py
@@ -13,7 +13,6 @@
 from spikeinterface import extract_waveforms, load_waveforms, download_dataset, compute_sparsity

-# from spikeinterface.widgets import HAVE_MPL, HAVE_SV

 import spikeinterface.extractors as se

@@ -36,7 +35,6 @@
 else:
     cache_folder = Path("cache_folder") / "widgets"

-print(cache_folder)

 ON_GITHUB = bool(os.getenv("GITHUB_ACTIONS"))
 KACHERY_CLOUD_SET = bool(os.getenv("KACHERY_CLOUD_CLIENT_ID")) and bool(os.getenv("KACHERY_CLOUD_PRIVATE_KEY"))
diff --git a/src/spikeinterface/widgets/unit_depths.py b/src/spikeinterface/widgets/unit_depths.py
index e48f274962..1aeae254c8 100644
--- a/src/spikeinterface/widgets/unit_depths.py
+++ b/src/spikeinterface/widgets/unit_depths.py
@@ -24,8 +24,6 @@ class UnitDepthsWidget(BaseWidget):
         Sign of peak for amplitudes, default 'neg'
     """

-    # possible_backends = {}
-
     def __init__(
         self, waveform_extractor, unit_colors=None, depth_axis=1, peak_sign="neg", backend=None, **backend_kwargs
     ):
diff --git a/src/spikeinterface/widgets/unit_locations.py b/src/spikeinterface/widgets/unit_locations.py
index f8ea042f84..42267e711f 100644
--- a/src/spikeinterface/widgets/unit_locations.py
+++ b/src/spikeinterface/widgets/unit_locations.py
@@ -33,8 +33,6 @@ class UnitLocationsWidget(BaseWidget):
         If True, the axis is set to off, default False (matplotlib backend)
     """

-    # possible_backends = {}
-
     def __init__(
         self,
         waveform_extractor: WaveformExtractor,

From 370dc66ab2f89b51d411b366e9d650443852043d Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 21 Jul 2023 12:16:34 +0000
Subject: [PATCH 146/166] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/spikeinterface/widgets/tests/test_widgets.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py
index 7bf508fe71..a5f75ebf50 100644
--- a/src/spikeinterface/widgets/tests/test_widgets.py
+++ b/src/spikeinterface/widgets/tests/test_widgets.py
@@ -14,7 +14,6 @@
 from spikeinterface import extract_waveforms, load_waveforms, download_dataset, compute_sparsity

-
 import spikeinterface.extractors as se
 import spikeinterface.widgets as sw
 import spikeinterface.comparison as sc

From fbebdec5e2ab66f6328d6eaf3ddea81fc0efc26c Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Sat, 22 Jul 2023 21:11:36 +0200
Subject: [PATCH 147/166] Fix missing import pandas

---
 src/spikeinterface/comparison/paircomparisons.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/spikeinterface/comparison/paircomparisons.py b/src/spikeinterface/comparison/paircomparisons.py
index 97269edc76..cd2b97aaf1 100644
--- a/src/spikeinterface/comparison/paircomparisons.py
+++ b/src/spikeinterface/comparison/paircomparisons.py
@@ -395,6 +395,8 @@ def get_performance(self, method="by_unit", output="pandas"):
         perf: pandas dataframe/series (or dict)
             dataframe/series (based on 'output') with performance entries
         """
+        import pandas as pd
+
         possibles = ("raw_count", "by_unit", "pooled_with_average")
         if method not in possibles:
             raise Exception("'method' can be " + " or ".join(possibles))

From 4c15161d60ef8d024b00f75d49e3285565526704 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Mon, 24 Jul 2023 10:24:50 +0200
Subject: [PATCH 148/166] Add simple test for get_performance output type

---
 .../comparison/tests/test_groundtruthcomparison.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/src/spikeinterface/comparison/tests/test_groundtruthcomparison.py b/src/spikeinterface/comparison/tests/test_groundtruthcomparison.py
index a2f043b9e7..03c2418411 100644
--- a/src/spikeinterface/comparison/tests/test_groundtruthcomparison.py
+++ b/src/spikeinterface/comparison/tests/test_groundtruthcomparison.py
@@ -55,8 +55,12 @@ def test_compare_sorter_to_ground_truth():
         "pooled_with_average",
     ]
     for method in methods:
-        perf = sc.get_performance(method=method)
-        # ~ print(perf)
+        import pandas as pd
+
+        perf_df = sc.get_performance(method=method, output="pandas")
+        assert isinstance(perf_df, pd.DataFrame)
+        perf_dict = sc.get_performance(method=method, output="dict")
+        assert isinstance(perf_dict, dict)

     for method in methods:
         sc.print_performance(method=method)

From 3bb91cdf94b4429debc3820843e079c879eba3ee Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Mon, 24 Jul 2023 10:47:47 +0200
Subject: [PATCH 149/166] Fix output='dict' for get_performance

---
 src/spikeinterface/comparison/paircomparisons.py   | 2 +-
 .../comparison/tests/test_groundtruthcomparison.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/spikeinterface/comparison/paircomparisons.py b/src/spikeinterface/comparison/paircomparisons.py
index cd2b97aaf1..3dc58b7a52 100644
--- a/src/spikeinterface/comparison/paircomparisons.py
+++ b/src/spikeinterface/comparison/paircomparisons.py
@@ -410,7 +410,7 @@ def get_performance(self, method="by_unit", output="pandas"):
         elif method == "pooled_with_average":
             perf = self.get_performance(method="by_unit").mean(axis=0)

-        if output == "dict" and isinstance(perf, pd.Series):
+        if output == "dict":
             perf = perf.to_dict()

         return perf
diff --git a/src/spikeinterface/comparison/tests/test_groundtruthcomparison.py b/src/spikeinterface/comparison/tests/test_groundtruthcomparison.py
index 03c2418411..fb3ee5d454 100644
--- a/src/spikeinterface/comparison/tests/test_groundtruthcomparison.py
+++ b/src/spikeinterface/comparison/tests/test_groundtruthcomparison.py
@@ -58,7 +58,7 @@ def test_compare_sorter_to_ground_truth():
         import pandas as pd

         perf_df = sc.get_performance(method=method, output="pandas")
-        assert isinstance(perf_df, pd.DataFrame)
+        assert isinstance(perf_df, (pd.Series, pd.DataFrame))
         perf_dict = sc.get_performance(method=method, output="dict")
         assert isinstance(perf_dict, dict)

From e19c3cd3ddd9d08c9abc2203aa6a64542fcb5055 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Mon, 24 Jul 2023 10:48:23 +0200
Subject: [PATCH 150/166] Fix output='dict' for get_performance 1

---
 src/spikeinterface/comparison/paircomparisons.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/spikeinterface/comparison/paircomparisons.py b/src/spikeinterface/comparison/paircomparisons.py
index 3dc58b7a52..75976ed44f 100644
--- a/src/spikeinterface/comparison/paircomparisons.py
+++ b/src/spikeinterface/comparison/paircomparisons.py
@@ -410,7 +410,7 @@ def get_performance(self, method="by_unit", output="pandas"):
         elif method == "pooled_with_average":
             perf = self.get_performance(method="by_unit").mean(axis=0)

-        if output == "dict":
+        if output == "dict" and isinstance(perf, (pd.DataFrame, pd.Series)):
             perf = perf.to_dict()

         return perf

From 131715fa00dd7f4894f2dee218da42df9655c14d Mon Sep 17 00:00:00 2001
From: Heberto Mayorquin
Date: Mon, 24 Jul 2023 11:50:46 +0200
Subject: [PATCH 151/166] For the cronjob, avoid re-downloading the cache dataset if present (#1862)

* avoid downloading the dataset if present

* added to both caches

* Update .github/workflows/caches_cron_job.yml

* Update .github/workflows/caches_cron_job.yml

---
 .github/workflows/caches_cron_job.yml | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/.github/workflows/caches_cron_job.yml b/.github/workflows/caches_cron_job.yml
index 3ed91b84c4..237612d5d3 100644
--- a/.github/workflows/caches_cron_job.yml
+++ b/.github/workflows/caches_cron_job.yml
@@ -33,6 +33,7 @@ jobs:
         with:
           path: ${{ github.workspace }}/test_env
           key: ${{ runner.os }}-venv-${{ steps.dependencies.outputs.hash }}-${{ steps.date.outputs.date }}
+          lookup-only: 'true'  # Avoids downloading the data, saving behavior is not affected.
       - name: Cache found?
         run: echo "Cache-hit == ${{steps.cache-venv.outputs.cache-hit == 'true'}}"
       - name: Create the virtual environment to be cached
@@ -64,6 +65,7 @@ jobs:
         with:
           path: ~/spikeinterface_datasets
           key: ${{ runner.os }}-datasets-${{ steps.repo_hash.outputs.dataset_hash }}
+          lookup-only: 'true'  # Avoids downloading the data, saving behavior is not affected.
       - name: Cache found?
         run: echo "Cache-hit == ${{steps.cache-datasets.outputs.cache-hit == 'true'}}"
       - name: Installing datalad and git-annex
@@ -88,7 +90,7 @@ jobs:
         run: |
           cd $HOME
           pwd
-          du -hs spikeinterface_datasets
+          du -hs spikeinterface_datasets  # Should show the size of ephy_testing_data
           cd spikeinterface_datasets
           pwd
           ls -lh  # Should show ephy_testing_data

From 63442db37780d61ec528ef574b3724244a2a8ae9 Mon Sep 17 00:00:00 2001
From: Garcia Samuel
Date: Mon, 24 Jul 2023 15:44:54 +0200
Subject: [PATCH 152/166] Update src/spikeinterface/core/basesorting.py

Co-authored-by: Alessio Buccino
---
 src/spikeinterface/core/basesorting.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py
index bad007aeae..91d820153c 100644
--- a/src/spikeinterface/core/basesorting.py
+++ b/src/spikeinterface/core/basesorting.py
@@ -508,7 +508,7 @@ def to_spike_vector(self, concatenated=True, extremum_channel_inds=None, use_cac
     def to_numpy_sorting(self, propagate_cache=True):
         """
         Turn any sorting in a NumpySorting.
-        usefull to have it in memory with a unique vector representation.
+        useful to have it in memory with a unique vector representation.

         Parameters
         ----------

From 25527fd3ad8f17ca8d8cdb89514150d983b5d31f Mon Sep 17 00:00:00 2001
From: Samuel Garcia
Date: Mon, 24 Jul 2023 16:03:20 +0200
Subject: [PATCH 153/166] Alessio suggestion for more test

---
 src/spikeinterface/core/tests/test_basesorting.py        | 5 +++++
 src/spikeinterface/core/tests/test_waveform_extractor.py | 3 ++-
 2 files changed, 7 insertions(+), 1 deletion(-)

diff --git a/src/spikeinterface/core/tests/test_basesorting.py b/src/spikeinterface/core/tests/test_basesorting.py
index d7559d9567..0bdd9aecdd 100644
--- a/src/spikeinterface/core/tests/test_basesorting.py
+++ b/src/spikeinterface/core/tests/test_basesorting.py
@@ -13,6 +13,7 @@
     NpzSortingExtractor,
     NumpyRecording,
     NumpySorting,
+    SharedMemorySorting,
     NpzFolderSorting,
     NumpyFolderSorting,
     create_sorting_npz,
@@ -121,6 +122,10 @@ def test_BaseSorting():

     sorting4 = sorting.to_numpy_sorting()
     sorting5 = sorting.to_multiprocessing(n_jobs=2)
+    # create a clone with the same share mem buffer
+    sorting6 = load_extractor(sorting5.to_dict())
+    assert isinstance(sorting6, SharedMemorySorting)
+    del sorting6

     del sorting5

diff --git a/src/spikeinterface/core/tests/test_waveform_extractor.py b/src/spikeinterface/core/tests/test_waveform_extractor.py
index e9d0462359..65dcff08d7 100644
--- a/src/spikeinterface/core/tests/test_waveform_extractor.py
+++ b/src/spikeinterface/core/tests/test_waveform_extractor.py
@@ -212,7 +212,8 @@ def test_extract_waveforms():
         if folder_sort.is_dir():
             shutil.rmtree(folder_sort)
     recording = recording.save(folder=folder_rec)
-    sorting = sorting.save(folder=folder_sort)
+    # we force "npz_folder" because we want to force the to_multiprocessing to be a SharedMemorySorting
+    sorting = sorting.save(folder=folder_sort, format='npz_folder')

     # 1 job
     folder1 = cache_folder / "test_extract_waveforms_1job"

From 0f11e98e8cf02890860b0347ed5c8496faea0b58 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 24 Jul 2023 14:04:25 +0000
Subject: [PATCH 154/166] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/spikeinterface/core/tests/test_waveform_extractor.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/spikeinterface/core/tests/test_waveform_extractor.py b/src/spikeinterface/core/tests/test_waveform_extractor.py
index 65dcff08d7..107ef5f180 100644
--- a/src/spikeinterface/core/tests/test_waveform_extractor.py
+++ b/src/spikeinterface/core/tests/test_waveform_extractor.py
@@ -213,7 +213,7 @@ def test_extract_waveforms():
             shutil.rmtree(folder_sort)
     recording = recording.save(folder=folder_rec)
     # we force "npz_folder" because we want to force the to_multiprocessing to be a SharedMemorySorting
-    sorting = sorting.save(folder=folder_sort, format='npz_folder')
+    sorting = sorting.save(folder=folder_sort, format="npz_folder")

     # 1 job
     folder1 = cache_folder / "test_extract_waveforms_1job"

From 604ca6a2d9712f904f5c7a176802c2cf3d98d571 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Mon, 24 Jul 2023 16:08:56 +0200
Subject: [PATCH 155/166] Move test imports to top

---
 .../comparison/tests/test_groundtruthcomparison.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/spikeinterface/comparison/tests/test_groundtruthcomparison.py b/src/spikeinterface/comparison/tests/test_groundtruthcomparison.py
index fb3ee5d454..931c989cef 100644
--- a/src/spikeinterface/comparison/tests/test_groundtruthcomparison.py
+++ b/src/spikeinterface/comparison/tests/test_groundtruthcomparison.py
@@ -1,6 +1,8 @@
 import numpy as np
 from numpy.testing import assert_array_equal

+import pandas as pd
+
 from spikeinterface.extractors import NumpySorting, toy_example
 from spikeinterface.comparison import compare_sorter_to_ground_truth

@@ -55,8 +57,6 @@ def test_compare_sorter_to_ground_truth():
         "pooled_with_average",
     ]
     for method in methods:
-        import pandas as pd
-
         perf_df = sc.get_performance(method=method, output="pandas")
         assert isinstance(perf_df, (pd.Series, pd.DataFrame))
         perf_dict = sc.get_performance(method=method, output="dict")

From 515cfc33da68feb36b1c1583c39ca0b15fcaad14 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Mon, 24 Jul 2023 19:38:05 +0200
Subject: [PATCH 156/166] Fix warnings in PCA metrics

---
 src/spikeinterface/qualitymetrics/pca_metrics.py | 14 ++++----------
 1 file changed, 4 insertions(+), 10 deletions(-)

diff --git a/src/spikeinterface/qualitymetrics/pca_metrics.py b/src/spikeinterface/qualitymetrics/pca_metrics.py
index 2a0feb4da8..97fc4aa14f 100644
--- a/src/spikeinterface/qualitymetrics/pca_metrics.py
+++ b/src/spikeinterface/qualitymetrics/pca_metrics.py
@@ -466,15 +466,12 @@ def nearest_neighbors_isolation(
     # if target unit has fewer than `min_spikes` spikes, print out a warning and return NaN
     if n_spikes_all_units[this_unit_id] < min_spikes:
         warnings.warn(
-            f"Warning: unit {this_unit_id} has fewer spikes than ",
-            f"specified by `min_spikes` ({min_spikes}); ",
-            f"returning NaN as the quality metric...",
+            f"Warning: unit {this_unit_id} has fewer spikes than specified by `min_spikes` ({min_spikes}); returning NaN as the quality metric..."
         )
         return np.nan, np.nan
     elif fr_all_units[this_unit_id] < min_fr:
         warnings.warn(
-            f"Warning: unit {this_unit_id} has a firing rate ",
-            f"below the specified `min_fr` ({min_fr}Hz); " f"returning NaN as the quality metric...",
+            f"Warning: unit {this_uit_id} has a firing rate below the specified `min_fr` ({min_fr}Hz); returning NaN as the quality metric..."
        )
         return np.nan, np.nan
     else:
@@ -652,15 +649,12 @@ def nearest_neighbors_noise_overlap(
     # if target unit has fewer than `min_spikes` spikes, print out a warning and return NaN
     if n_spikes_all_units[this_unit_id] < min_spikes:
         warnings.warn(
-            f"Warning: unit {this_unit_id} has fewer spikes than ",
-            f"specified by `min_spikes` ({min_spikes}); ",
-            f"returning NaN as the quality metric...",
+            f"Warning: unit {this_unit_id} has fewer spikes than specified by `min_spikes` ({min_spikes}); returning NaN as the quality metric..."
        )
         return np.nan
     elif fr_all_units[this_unit_id] < min_fr:
         warnings.warn(
-            f"Warning: unit {this_unit_id} has a firing rate ",
-            f"below the specified `min_fr` ({min_fr}Hz); " f"returning NaN as the quality metric...",
+            f"Warning: unit {this_unit_id} has a firing rate below the specified `min_fr` ({min_fr}Hz); returning NaN as the quality metric...",
        )
         return np.nan
     else:

From fb4e3d903a588695ceb750c25f29a856eaef0e5d Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Mon, 24 Jul 2023 19:39:48 +0200
Subject: [PATCH 157/166] line breaks

---
 src/spikeinterface/qualitymetrics/pca_metrics.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/src/spikeinterface/qualitymetrics/pca_metrics.py b/src/spikeinterface/qualitymetrics/pca_metrics.py
index 97fc4aa14f..1644559416 100644
--- a/src/spikeinterface/qualitymetrics/pca_metrics.py
+++ b/src/spikeinterface/qualitymetrics/pca_metrics.py
@@ -466,12 +466,14 @@ def nearest_neighbors_isolation(
     # if target unit has fewer than `min_spikes` spikes, print out a warning and return NaN
     if n_spikes_all_units[this_unit_id] < min_spikes:
         warnings.warn(
-            f"Warning: unit {this_unit_id} has fewer spikes than specified by `min_spikes` ({min_spikes}); returning NaN as the quality metric..."
+            f"Warning: unit {this_unit_id} has fewer spikes than specified by `min_spikes` "
+            f"({min_spikes}); returning NaN as the quality metric..."
         )
         return np.nan, np.nan
     elif fr_all_units[this_unit_id] < min_fr:
         warnings.warn(
-            f"Warning: unit {this_uit_id} has a firing rate below the specified `min_fr` ({min_fr}Hz); returning NaN as the quality metric..."
+            f"Warning: unit {this_uit_id} has a firing rate below the specified `min_fr` "
+            f"({min_fr}Hz); returning NaN as the quality metric..."
         )
         return np.nan, np.nan
     else:
@@ -649,12 +651,14 @@ def nearest_neighbors_noise_overlap(
     # if target unit has fewer than `min_spikes` spikes, print out a warning and return NaN
     if n_spikes_all_units[this_unit_id] < min_spikes:
         warnings.warn(
-            f"Warning: unit {this_unit_id} has fewer spikes than specified by `min_spikes` ({min_spikes}); returning NaN as the quality metric..."
+            f"Warning: unit {this_unit_id} has fewer spikes than specified by `min_spikes` "
+            f"({min_spikes}); returning NaN as the quality metric..."
        )
         return np.nan
     elif fr_all_units[this_unit_id] < min_fr:
         warnings.warn(
-            f"Warning: unit {this_unit_id} has a firing rate below the specified `min_fr` ({min_fr}Hz); returning NaN as the quality metric...",
+            f"Warning: unit {this_unit_id} has a firing rate below the specified `min_fr` "
+            f"({min_fr}Hz); returning NaN as the quality metric...",
        )
         return np.nan
     else:

From d67a2a791963fa74f1d3afcb0f10c0ede74b22ce Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Mon, 24 Jul 2023 20:10:13 +0200
Subject: [PATCH 158/166] Fix NWB streaming: do not convert to Path if ros3 or fsspec!
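
Converting the input path with pathlib only makes sense for local files:
with stream_mode="ros3" or "fsspec" the file_path is a URL, and absolutizing
it silently corrupts the scheme. A minimal sketch of the failure mode on a
POSIX system (the URL below is a hypothetical placeholder, not taken from
the patch):

    from pathlib import Path

    url = "https://example.org/dataset/file.nwb"  # hypothetical URL
    print(str(Path(url).absolute()))
    # -> "/current/working/dir/https:/example.org/dataset/file.nwb"
    # pathlib collapses "https://" into "https:/" and prepends the cwd,
    # so the remote file can no longer be resolved by the HDF5 drivers.

Keeping the Path(...).absolute() call in the local-file branch only leaves
streaming URLs untouched for both drivers.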
--- src/spikeinterface/extractors/nwbextractors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/extractors/nwbextractors.py b/src/spikeinterface/extractors/nwbextractors.py index d0b56342dd..b50ac76d26 100644 --- a/src/spikeinterface/extractors/nwbextractors.py +++ b/src/spikeinterface/extractors/nwbextractors.py @@ -104,7 +104,6 @@ def read_nwbfile( -------- >>> nwbfile = read_nwbfile("data.nwb", stream_mode="ros3") """ - file_path = str(Path(file_path).absolute()) from pynwb import NWBHDF5IO, NWBFile if stream_mode == "fsspec": @@ -131,6 +130,7 @@ def read_nwbfile( io = NWBHDF5IO(path=file_path, mode="r", load_namespaces=True, driver="ros3") else: + file_path = str(Path(file_path).absolute()) io = NWBHDF5IO(path=file_path, mode="r", load_namespaces=True) nwbfile = io.read() From 88e70f1229eaaf0041d11c4edf5dfb9a952fa403 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 24 Jul 2023 20:23:12 +0200 Subject: [PATCH 159/166] Trigger streaming tests if streaming files changed --- .../workflows/streaming-extractor-test.yml | 22 ++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/.github/workflows/streaming-extractor-test.yml b/.github/workflows/streaming-extractor-test.yml index 1498684d77..37f83dc666 100644 --- a/.github/workflows/streaming-extractor-test.yml +++ b/.github/workflows/streaming-extractor-test.yml @@ -1,6 +1,10 @@ name: Test streaming extractors -on: workflow_dispatch +on: + pull_request: + types: [synchronize, opened, reopened] + branches: + - main concurrency: # Cancel previous workflows on the same pull request group: ${{ github.workflow }}-${{ github.ref }} @@ -28,9 +32,20 @@ jobs: - run: git fetch --prune --unshallow --tags - name: Install openblas run: sudo apt install libopenblas-dev # Necessary for ROS3 support - - name: Install package and streaming extractor dependencies + - name: Get changed files + id: changed-files + uses: tj-actions/changed-files@v35 + - name: Module changes + id: modules-changed run: | - pip install -e .[test_core,streaming_extractors] + for file in ${{ steps.changed-files.outputs.all_changed_files }}; do + if [[ $file == *"nwbextractors.py" || $file == *"iblstreamingrecording.py" ]]; then + echo "Streaming files changed changed" + echo "STREAMING_CHANGED=true" >> $GITHUB_OUTPUT + fi + - name: Install package and streaming extractor dependencies + if: steps.modules-changed.outputs.STREAMING_CHANGED == 'true' + run: pip install -e .[test_core,streaming_extractors] # Temporary disabled because of complicated error with path # - name: Install h5py with ROS3 support and test it works # run: | @@ -38,4 +53,5 @@ jobs: # conda install -c conda-forge "h5py>=3.2" # python -c "import h5py; assert 'ros3' in h5py.registered_drivers(), f'ros3 suppport not available, failed to install'" - name: run tests + if: steps.modules-changed.outputs.STREAMING_CHANGED == 'true' run: pytest -m "streaming_extractors and not ros3_test" -vv -ra From 3a657ab835e3c6734a5be9a9804982b4599d1548 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 25 Jul 2023 09:06:04 +0200 Subject: [PATCH 160/166] Fix action --- .github/workflows/streaming-extractor-test.yml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/streaming-extractor-test.yml b/.github/workflows/streaming-extractor-test.yml index 37f83dc666..064d38fcc4 100644 --- a/.github/workflows/streaming-extractor-test.yml +++ b/.github/workflows/streaming-extractor-test.yml @@ -39,12 +39,13 @@ jobs: id: 
From 88e70f1229eaaf0041d11c4edf5dfb9a952fa403 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Mon, 24 Jul 2023 20:23:12 +0200
Subject: [PATCH 159/166] Trigger streaming tests if streaming files changed

---
 .github/workflows/streaming-extractor-test.yml | 22 ++++++++++++++++---
 1 file changed, 19 insertions(+), 3 deletions(-)

diff --git a/.github/workflows/streaming-extractor-test.yml b/.github/workflows/streaming-extractor-test.yml
index 1498684d77..37f83dc666 100644
--- a/.github/workflows/streaming-extractor-test.yml
+++ b/.github/workflows/streaming-extractor-test.yml
@@ -1,6 +1,10 @@
 name: Test streaming extractors

-on: workflow_dispatch
+on:
+  pull_request:
+    types: [synchronize, opened, reopened]
+    branches:
+      - main

 concurrency:  # Cancel previous workflows on the same pull request
   group: ${{ github.workflow }}-${{ github.ref }}
@@ -28,9 +32,20 @@ jobs:
       - run: git fetch --prune --unshallow --tags
       - name: Install openblas
         run: sudo apt install libopenblas-dev # Necessary for ROS3 support
-      - name: Install package and streaming extractor dependencies
+      - name: Get changed files
+        id: changed-files
+        uses: tj-actions/changed-files@v35
+      - name: Module changes
+        id: modules-changed
         run: |
-          pip install -e .[test_core,streaming_extractors]
+          for file in ${{ steps.changed-files.outputs.all_changed_files }}; do
+            if [[ $file == *"nwbextractors.py" || $file == *"iblstreamingrecording.py" ]]; then
+              echo "Streaming files changed"
+              echo "STREAMING_CHANGED=true" >> $GITHUB_OUTPUT
+            fi
+      - name: Install package and streaming extractor dependencies
+        if: steps.modules-changed.outputs.STREAMING_CHANGED == 'true'
+        run: pip install -e .[test_core,streaming_extractors]
       # Temporarily disabled because of complicated error with path
       # - name: Install h5py with ROS3 support and test it works
       #   run: |
@@ -38,4 +53,5 @@ jobs:
       #     conda install -c conda-forge "h5py>=3.2"
       #     python -c "import h5py; assert 'ros3' in h5py.registered_drivers(), f'ros3 support not available, failed to install'"
       - name: run tests
+        if: steps.modules-changed.outputs.STREAMING_CHANGED == 'true'
         run: pytest -m "streaming_extractors and not ros3_test" -vv -ra

From 3a657ab835e3c6734a5be9a9804982b4599d1548 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Tue, 25 Jul 2023 09:06:04 +0200
Subject: [PATCH 160/166] Fix action

---
 .github/workflows/streaming-extractor-test.yml | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/.github/workflows/streaming-extractor-test.yml b/.github/workflows/streaming-extractor-test.yml
index 37f83dc666..064d38fcc4 100644
--- a/.github/workflows/streaming-extractor-test.yml
+++ b/.github/workflows/streaming-extractor-test.yml
@@ -39,12 +39,13 @@ jobs:
         id: modules-changed
         run: |
           for file in ${{ steps.changed-files.outputs.all_changed_files }}; do
-            if [[ $file == *"nwbextractors.py" || $file == *"iblstreamingrecording.py" ]]; then
+            if [[ $file == *"/nwbextractors.py" || $file == *"/iblstreamingrecording.py"* ]]; then
               echo "Streaming files changed"
               echo "STREAMING_CHANGED=true" >> $GITHUB_OUTPUT
             fi
+          done
       - name: Install package and streaming extractor dependencies
-        if: steps.modules-changed.outputs.STREAMING_CHANGED == 'true'
+        if: ${{ steps.modules-changed.outputs.STREAMING_CHANGED == 'true' }}
         run: pip install -e .[test_core,streaming_extractors]
       # Temporarily disabled because of complicated error with path
       # - name: Install h5py with ROS3 support and test it works
       #   run: |
       #     pip uninstall -y h5py
       #     conda install -c conda-forge "h5py>=3.2"
       #     python -c "import h5py; assert 'ros3' in h5py.registered_drivers(), f'ros3 support not available, failed to install'"
       - name: run tests
         if: steps.modules-changed.outputs.STREAMING_CHANGED == 'true'
         run: pytest -m "streaming_extractors and not ros3_test" -vv -ra

From e77dcd9d12ab37dc048b4328ad158250b4c9d22c Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Tue, 25 Jul 2023 10:11:36 +0200
Subject: [PATCH 161/166] Fix NwbSorting streaming and add tests

---
 .../extractors/nwbextractors.py               | 14 ++---
 .../extractors/tests/test_nwb_s3_extractor.py | 59 ++++++++++++++++++-
 2 files changed, 62 insertions(+), 11 deletions(-)

diff --git a/src/spikeinterface/extractors/nwbextractors.py b/src/spikeinterface/extractors/nwbextractors.py
index b50ac76d26..bca4c75d99 100644
--- a/src/spikeinterface/extractors/nwbextractors.py
+++ b/src/spikeinterface/extractors/nwbextractors.py
@@ -475,19 +475,17 @@ def __init__(
             self.stream_cache_path = stream_cache_path if stream_cache_path is not None else "cache"
             self.cfs = CachingFileSystem(
                 fs=fsspec.filesystem("http"),
-                cache_storage=self.stream_cache_path,
+                cache_storage=str(self.stream_cache_path),
             )
-            self._file_path = self.cfs.open(str(Path(file_path).absolute()), "rb")
-            file = h5py.File(self._file_path)
+            file_path_ = self.cfs.open(file_path, "rb")
+            file = h5py.File(file_path_)
             self.io = NWBHDF5IO(file=file, mode="r", load_namespaces=True)

         elif stream_mode == "ros3":
-            self._file_path = str(Path(file_path).absolute())
-            self.io = NWBHDF5IO(self._file_path, mode="r", load_namespaces=True, driver="ros3")
-
+            self.io = NWBHDF5IO(file_path, mode="r", load_namespaces=True, driver="ros3")
         else:
-            self._file_path = str(Path(file_path).absolute())
-            self.io = NWBHDF5IO(self._file_path, mode="r", load_namespaces=True)
+            file_path_ = str(Path(file_path).absolute())
+            self.io = NWBHDF5IO(file_path_, mode="r", load_namespaces=True)

         self._nwbfile = self.io.read()
         units_ids = list(self._nwbfile.units.id[:])

diff --git a/src/spikeinterface/extractors/tests/test_nwb_s3_extractor.py b/src/spikeinterface/extractors/tests/test_nwb_s3_extractor.py
index a41fd080b4..71a19f30d3 100644
--- a/src/spikeinterface/extractors/tests/test_nwb_s3_extractor.py
+++ b/src/spikeinterface/extractors/tests/test_nwb_s3_extractor.py
@@ -4,7 +4,7 @@
 import numpy as np
 import h5py

-from spikeinterface.extractors import NwbRecordingExtractor
+from spikeinterface.extractors import NwbRecordingExtractor, NwbSortingExtractor

 if hasattr(pytest, "global_test_folder"):
     cache_folder = pytest.global_test_folder / "extractors"
@@ -15,7 +15,7 @@
 @pytest.mark.ros3_test
 @pytest.mark.streaming_extractors
 @pytest.mark.skipif("ros3" not in h5py.registered_drivers(), reason="ROS3 driver not installed")
-def test_s3_nwb_ros3():
+def test_recording_s3_nwb_ros3():
     file_path = (
         "https://dandi-api-staging-dandisets.s3.amazonaws.com/blobs/5f4/b7a/5f4b7a1f-7b95-4ad8-9579-4df6025371cc"
     )
@@ -42,7 +42,7 @@


 @pytest.mark.streaming_extractors
-def test_s3_nwb_fsspec():
+def test_recording_s3_nwb_fsspec():
     file_path = (
"https://dandi-api-staging-dandisets.s3.amazonaws.com/blobs/5f4/b7a/5f4b7a1f-7b95-4ad8-9579-4df6025371cc" ) @@ -66,3 +66,56 @@ def test_s3_nwb_fsspec(): if rec.has_scaled(): trace_scaled = rec.get_traces(segment_index=segment_index, return_scaled=True, end_frame=2) assert trace_scaled.dtype == "float32" + + +@pytest.mark.ros3_test +@pytest.mark.streaming_extractors +@pytest.mark.skipif("ros3" not in h5py.registered_drivers(), reason="ROS3 driver not installed") +def test_sorting_s3_nwb_ros3(): + file_path = "https://dandiarchive.s3.amazonaws.com/blobs/84b/aa4/84baa446-cf19-43e8-bdeb-fc804852279b" + # we provide the 'sampling_frequency' because the NWB file does not the electrical series + sort = NwbSortingExtractor(file_path, sampling_frequency=30000, stream_mode="ros3") + + start_frame = 0 + end_frame = 300 + num_frames = end_frame - start_frame + + num_seg = sort.get_num_segments() + num_units = len(sort.unit_ids) + + for segment_index in range(num_seg): + for unit in sort.unit_ids: + spike_train = sort.get_unit_spike_train(unit_id=unit, segment_index=segment_index) + assert len(spike_train) > 0 + assert spike_train.dtype == "int64" + assert np.all(spike_train >= 0) + + +@pytest.mark.streaming_extractors +def test_sorting_s3_nwb_fsspec(): + file_path = "https://dandiarchive.s3.amazonaws.com/blobs/84b/aa4/84baa446-cf19-43e8-bdeb-fc804852279b" + # we provide the 'sampling_frequency' because the NWB file does not the electrical series + sort = NwbSortingExtractor( + file_path, sampling_frequency=30000, stream_mode="fsspec", stream_cache_path=cache_folder + ) + + start_frame = 0 + end_frame = 300 + num_frames = end_frame - start_frame + + num_seg = sort.get_num_segments() + num_units = len(sort.unit_ids) + + for segment_index in range(num_seg): + for unit in sort.unit_ids: + spike_train = sort.get_unit_spike_train(unit_id=unit, segment_index=segment_index) + assert len(spike_train) > 0 + assert spike_train.dtype == "int64" + assert np.all(spike_train >= 0) + + +if __name__ == "__main__": + test_recording_s3_nwb_ros3() + test_recording_s3_nwb_fsspec() + test_sorting_s3_nwb_ros3() + test_sorting_s3_nwb_fsspec() From 5fbd38872f4fc89ac2b0c8fbb0ec29a8c5f164e8 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 25 Jul 2023 10:32:47 +0200 Subject: [PATCH 162/166] Update src/spikeinterface/qualitymetrics/pca_metrics.py --- src/spikeinterface/qualitymetrics/pca_metrics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/qualitymetrics/pca_metrics.py b/src/spikeinterface/qualitymetrics/pca_metrics.py index 1644559416..45d5e80379 100644 --- a/src/spikeinterface/qualitymetrics/pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/pca_metrics.py @@ -472,7 +472,7 @@ def nearest_neighbors_isolation( return np.nan, np.nan elif fr_all_units[this_unit_id] < min_fr: warnings.warn( - f"Warning: unit {this_uit_id} has a firing rate below the specified `min_fr` " + f"Warning: unit {this_unit_id} has a firing rate below the specified `min_fr` " f"({min_fr}Hz); returning NaN as the quality metric..." 
From 5fbd38872f4fc89ac2b0c8fbb0ec29a8c5f164e8 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Tue, 25 Jul 2023 10:32:47 +0200
Subject: [PATCH 162/166] Update src/spikeinterface/qualitymetrics/pca_metrics.py

---
 src/spikeinterface/qualitymetrics/pca_metrics.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/spikeinterface/qualitymetrics/pca_metrics.py b/src/spikeinterface/qualitymetrics/pca_metrics.py
index 1644559416..45d5e80379 100644
--- a/src/spikeinterface/qualitymetrics/pca_metrics.py
+++ b/src/spikeinterface/qualitymetrics/pca_metrics.py
@@ -472,7 +472,7 @@ def nearest_neighbors_isolation(
         return np.nan, np.nan
     elif fr_all_units[this_unit_id] < min_fr:
         warnings.warn(
-            f"Warning: unit {this_uit_id} has a firing rate below the specified `min_fr` "
+            f"Warning: unit {this_unit_id} has a firing rate below the specified `min_fr` "
             f"({min_fr}Hz); returning NaN as the quality metric..."
         )
         return np.nan, np.nan

From 9904c35ae924e9d2b17a2feb43729cd5991fc354 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Tue, 25 Jul 2023 10:33:00 +0200
Subject: [PATCH 163/166] Update src/spikeinterface/qualitymetrics/pca_metrics.py

Co-authored-by: Heberto Mayorquin
---
 src/spikeinterface/qualitymetrics/pca_metrics.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/spikeinterface/qualitymetrics/pca_metrics.py b/src/spikeinterface/qualitymetrics/pca_metrics.py
index 45d5e80379..97cb746363 100644
--- a/src/spikeinterface/qualitymetrics/pca_metrics.py
+++ b/src/spikeinterface/qualitymetrics/pca_metrics.py
@@ -473,7 +473,7 @@ def nearest_neighbors_isolation(
     elif fr_all_units[this_unit_id] < min_fr:
         warnings.warn(
             f"Warning: unit {this_unit_id} has a firing rate below the specified `min_fr` "
-            f"({min_fr}Hz); returning NaN as the quality metric..."
+            f"({min_fr} Hz); returning NaN as the quality metric..."
         )
         return np.nan, np.nan
     else:

From f546932f923190df247d1dee3363c729fb182176 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Tue, 25 Jul 2023 10:33:07 +0200
Subject: [PATCH 164/166] Update src/spikeinterface/qualitymetrics/pca_metrics.py

Co-authored-by: Heberto Mayorquin
---
 src/spikeinterface/qualitymetrics/pca_metrics.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/spikeinterface/qualitymetrics/pca_metrics.py b/src/spikeinterface/qualitymetrics/pca_metrics.py
index 97cb746363..d4eff9218e 100644
--- a/src/spikeinterface/qualitymetrics/pca_metrics.py
+++ b/src/spikeinterface/qualitymetrics/pca_metrics.py
@@ -658,7 +658,7 @@ def nearest_neighbors_noise_overlap(
     elif fr_all_units[this_unit_id] < min_fr:
         warnings.warn(
             f"Warning: unit {this_unit_id} has a firing rate below the specified `min_fr` "
-            f"({min_fr}Hz); returning NaN as the quality metric...",
+            f"({min_fr} Hz); returning NaN as the quality metric...",
         )
         return np.nan
     else:

From f4c511b3a49ccd1736c7430664d71c33611938e1 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Tue, 25 Jul 2023 10:58:51 +0200
Subject: [PATCH 165/166] Remove redundant warning

---
 src/spikeinterface/qualitymetrics/pca_metrics.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/spikeinterface/qualitymetrics/pca_metrics.py b/src/spikeinterface/qualitymetrics/pca_metrics.py
index d4eff9218e..e725498773 100644
--- a/src/spikeinterface/qualitymetrics/pca_metrics.py
+++ b/src/spikeinterface/qualitymetrics/pca_metrics.py
@@ -466,13 +466,13 @@ def nearest_neighbors_isolation(
     # if target unit has fewer than `min_spikes` spikes, print out a warning and return NaN
     if n_spikes_all_units[this_unit_id] < min_spikes:
         warnings.warn(
-            f"Warning: unit {this_unit_id} has fewer spikes than specified by `min_spikes` "
+            f"Unit {this_unit_id} has fewer spikes than specified by `min_spikes` "
             f"({min_spikes}); returning NaN as the quality metric..."
         )
         return np.nan, np.nan
     elif fr_all_units[this_unit_id] < min_fr:
         warnings.warn(
-            f"Warning: unit {this_unit_id} has a firing rate below the specified `min_fr` "
+            f"Unit {this_unit_id} has a firing rate below the specified `min_fr` "
             f"({min_fr} Hz); returning NaN as the quality metric..."
         )
         return np.nan, np.nan
@@ -651,13 +651,13 @@ def nearest_neighbors_noise_overlap(
     # if target unit has fewer than `min_spikes` spikes, print out a warning and return NaN
     if n_spikes_all_units[this_unit_id] < min_spikes:
         warnings.warn(
-            f"Warning: unit {this_unit_id} has fewer spikes than specified by `min_spikes` "
+            f"Unit {this_unit_id} has fewer spikes than specified by `min_spikes` "
             f"({min_spikes}); returning NaN as the quality metric..."
         )
         return np.nan
     elif fr_all_units[this_unit_id] < min_fr:
         warnings.warn(
-            f"Warning: unit {this_unit_id} has a firing rate below the specified `min_fr` "
+            f"Unit {this_unit_id} has a firing rate below the specified `min_fr` "
             f"({min_fr} Hz); returning NaN as the quality metric...",
         )
         return np.nan
     else:

From 4245f0f822ca0d48e50a24ab94fb5f705ce856da Mon Sep 17 00:00:00 2001
From: Samuel Garcia
Date: Wed, 26 Jul 2023 16:01:00 +0200
Subject: [PATCH 166/166] Comments from Ramon.

---
 src/spikeinterface/core/basesorting.py   | 6 +++---
 src/spikeinterface/core/sortingfolder.py | 8 ++++----
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py
index 91d820153c..56f46f0a38 100644
--- a/src/spikeinterface/core/basesorting.py
+++ b/src/spikeinterface/core/basesorting.py
@@ -421,15 +421,15 @@ def to_spike_vector(self, concatenated=True, extremum_channel_inds=None, use_cache

         Parameters
         ----------
         concatenated: bool
-            By default the output is one numpy vector with all spikes from all segments
-            With concatenated=False then it is a list of spike vector by segment.
+            With concatenated=True (default) the output is one numpy "spike vector" with spikes from all segments.
+            With concatenated=False the output is a list of "spike vectors", one per segment.
         extremum_channel_inds: None or dict
             If a dictionary of unit_id to channel_ind is given, then an extra field 'channel_index' is added.
             This can be convenient for computing spike positions after sorting.
             This dict can be computed with `get_template_extremum_channel(we, outputs="index")`
         use_cache: bool
-            When True (default) the spikes vector is cache in an attribute of the object.
+            When True (default) the spikes vector is cached as an attribute of the object (`_cached_spike_vector`).
             This caching only occurs when extremum_channel_inds=None.

         Returns
diff --git a/src/spikeinterface/core/sortingfolder.py b/src/spikeinterface/core/sortingfolder.py
index 49619bca06..c813c26442 100644
--- a/src/spikeinterface/core/sortingfolder.py
+++ b/src/spikeinterface/core/sortingfolder.py
@@ -31,11 +31,11 @@ def __init__(self, folder_path):
         folder_path = Path(folder_path)

         with open(folder_path / "numpysorting_info.json", "r") as f:
-            d = json.load(f)
+            info = json.load(f)

-        sampling_frequency = d["sampling_frequency"]
-        unit_ids = np.array(d["unit_ids"])
-        num_segments = d["num_segments"]
+        sampling_frequency = info["sampling_frequency"]
+        unit_ids = np.array(info["unit_ids"])
+        num_segments = info["num_segments"]

         BaseSorting.__init__(self, sampling_frequency, unit_ids)
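To close, a usage sketch for the `to_spike_vector` docstring amended above; `sorting` and `we` stand for a hypothetical BaseSorting and WaveformExtractor pair, and `get_template_extremum_channel` is the helper the docstring itself cites:

```python
from spikeinterface.core import get_template_extremum_channel

# Map each unit to the channel index of its template extremum, then request a
# spike vector carrying the extra 'channel_index' field described above.
extremum_channel_inds = get_template_extremum_channel(we, outputs="index")
spikes = sorting.to_spike_vector(concatenated=True, extremum_channel_inds=extremum_channel_inds)
# With extremum_channel_inds=None and use_cache=True, the result would instead
# be memoized on the sorting object as `_cached_spike_vector`.
```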