From f498179431d23291f72bbea40bc6a95b65dc5913 Mon Sep 17 00:00:00 2001 From: Julia Sprenger Date: Thu, 22 Jun 2023 10:29:01 +0200 Subject: [PATCH 01/73] Add plexon2 recording, sorting and event support --- .../extractors/neoextractors/__init__.py | 7 +- .../extractors/neoextractors/plexon2.py | 102 ++++++++++++++++++ .../extractors/tests/test_neoextractors.py | 18 ++++ 3 files changed, 125 insertions(+), 2 deletions(-) create mode 100644 src/spikeinterface/extractors/neoextractors/plexon2.py diff --git a/src/spikeinterface/extractors/neoextractors/__init__.py b/src/spikeinterface/extractors/neoextractors/__init__.py index 0d9da1960a..4c12017328 100644 --- a/src/spikeinterface/extractors/neoextractors/__init__.py +++ b/src/spikeinterface/extractors/neoextractors/__init__.py @@ -25,6 +25,8 @@ read_openephys_event, ) from .plexon import PlexonRecordingExtractor, PlexonSortingExtractor, read_plexon, read_plexon_sorting +from .plexon2 import (Plexon2SortingExtractor, Plexon2RecordingExtractor, Plexon2EventExtractor, + read_plexon2, read_plexon2_sorting, read_plexon2_event) from .spike2 import Spike2RecordingExtractor, read_spike2 from .spikegadgets import SpikeGadgetsRecordingExtractor, read_spikegadgets from .spikeglx import SpikeGLXRecordingExtractor, read_spikeglx @@ -49,12 +51,13 @@ OpenEphysBinaryRecordingExtractor, OpenEphysLegacyRecordingExtractor, PlexonRecordingExtractor, + Plexon2RecordingExtractor, Spike2RecordingExtractor, SpikeGadgetsRecordingExtractor, SpikeGLXRecordingExtractor, TdtRecordingExtractor, ] -neo_sorting_extractors_list = [BlackrockSortingExtractor, MEArecSortingExtractor, NeuralynxSortingExtractor] +neo_sorting_extractors_list = [BlackrockSortingExtractor, MEArecSortingExtractor, NeuralynxSortingExtractor, Plexon2SortingExtractor] -neo_event_extractors_list = [AlphaOmegaEventExtractor, OpenEphysBinaryEventExtractor] +neo_event_extractors_list = [AlphaOmegaEventExtractor, OpenEphysBinaryEventExtractor, Plexon2EventExtractor] diff --git a/src/spikeinterface/extractors/neoextractors/plexon2.py b/src/spikeinterface/extractors/neoextractors/plexon2.py new file mode 100644 index 0000000000..5ccceac875 --- /dev/null +++ b/src/spikeinterface/extractors/neoextractors/plexon2.py @@ -0,0 +1,102 @@ +from spikeinterface.core.core_tools import define_function_from_class + +from .neobaseextractor import (NeoBaseRecordingExtractor, NeoBaseSortingExtractor, + NeoBaseEventExtractor) + + +class Plexon2RecordingExtractor(NeoBaseRecordingExtractor): + """ + Class for reading plexon pl2 files. + + Based on :py:class:`neo.rawio.Plexon2RawIO` + + Parameters + ---------- + file_path: str + The file path to load the recordings from. + stream_id: str, optional + If there are several streams, specify the stream id you want to load. + stream_name: str, optional + If there are several streams, specify the stream name you want to load. + all_annotations: bool, default: False + Load exhaustively all annotations from neo. 
+ """ + + mode = "file" + NeoRawIOClass = "Plexon2RawIO" + name = "plexon2" + + def __init__(self, file_path, stream_id=None, stream_name=None, all_annotations=False): + neo_kwargs = self.map_to_neo_kwargs(file_path) + NeoBaseRecordingExtractor.__init__( + self, stream_id=stream_id, stream_name=stream_name, all_annotations=all_annotations, **neo_kwargs + ) + self._kwargs.update({"file_path": str(file_path)}) + + @classmethod + def map_to_neo_kwargs(cls, file_path): + neo_kwargs = {"filename": str(file_path)} + return neo_kwargs + + +class Plexon2SortingExtractor(NeoBaseSortingExtractor): + """ + Class for reading plexon spiking data from .pl2 files. + + Based on :py:class:`neo.rawio.Plexon2RawIO` + + Parameters + ---------- + file_path: str + The file path to load the recordings from. + """ + + mode = "file" + NeoRawIOClass = "Plexon2RawIO" + handle_spike_frame_directly = True + name = "plexon2" + + def __init__(self, file_path): + from neo.rawio import Plexon2RawIO + + neo_kwargs = self.map_to_neo_kwargs(file_path) + neo_reader = Plexon2RawIO(**neo_kwargs) + neo_reader.parse_header() + NeoBaseSortingExtractor.__init__(self, **neo_kwargs) + self._kwargs.update({"file_path": str(file_path)}) + + @classmethod + def map_to_neo_kwargs(cls, file_path): + neo_kwargs = {"filename": str(file_path)} + return neo_kwargs + + +class Plexon2EventExtractor(NeoBaseEventExtractor): + """ + Class for reading plexon spiking data from .pl2 files. + + Based on :py:class:`neo.rawio.Plexon2RawIO` + + Parameters + ---------- + folder_path: str + + """ + + mode = "file" + NeoRawIOClass = "Plexon2RawIO" + name = "plexon2" + + def __init__(self, folder_path, block_index=None): + neo_kwargs = self.map_to_neo_kwargs(folder_path) + NeoBaseEventExtractor.__init__(self, block_index=block_index, **neo_kwargs) + + @classmethod + def map_to_neo_kwargs(cls, folder_path): + neo_kwargs = {"filename": str(folder_path)} + return neo_kwargs + + +read_plexon2 = define_function_from_class(source_class=Plexon2RecordingExtractor, name="read_plexon2") +read_plexon2_sorting = define_function_from_class(source_class=Plexon2SortingExtractor, name="read_plexon2_sorting") +read_plexon2_event = define_function_from_class(source_class=Plexon2EventExtractor, name='read_plexon2_event') diff --git a/src/spikeinterface/extractors/tests/test_neoextractors.py b/src/spikeinterface/extractors/tests/test_neoextractors.py index d19574e094..a28752acdd 100644 --- a/src/spikeinterface/extractors/tests/test_neoextractors.py +++ b/src/spikeinterface/extractors/tests/test_neoextractors.py @@ -120,6 +120,12 @@ class NeuroScopeSortingTest(SortingCommonTestSuite, unittest.TestCase): }, ] +class Plexon2EventTest(EventCommonTestSuite, unittest.TestCase): + ExtractorClass = Plexon2EventExtractor + downloads = ["plexon"] + entities = [ + ("plexon/4chDemoPL2.pl2"), + ] class PlexonRecordingTest(RecordingCommonTestSuite, unittest.TestCase): ExtractorClass = PlexonRecordingExtractor @@ -128,6 +134,12 @@ class PlexonRecordingTest(RecordingCommonTestSuite, unittest.TestCase): "plexon/File_plexon_3.plx", ] +class Plexon2RecordingTest(RecordingCommonTestSuite, unittest.TestCase): + ExtractorClass = Plexon2RecordingExtractor + downloads = ["plexon"] + entities = [ + ("plexon/4chDemoPL2.pl2", {"stream_id": "3"}), + ] class PlexonSortingTest(SortingCommonTestSuite, unittest.TestCase): ExtractorClass = PlexonSortingExtractor @@ -136,6 +148,12 @@ class PlexonSortingTest(SortingCommonTestSuite, unittest.TestCase): ("plexon/File_plexon_1.plx"), ] +class 
Plexon2SortingTest(SortingCommonTestSuite, unittest.TestCase): + ExtractorClass = Plexon2SortingExtractor + downloads = ["plexon"] + entities = [ + ("plexon/4chDemoPL2.pl2"), + ] class NeuralynxRecordingTest(RecordingCommonTestSuite, unittest.TestCase): ExtractorClass = NeuralynxRecordingExtractor From 9f42895213a47ddb9158e4cccb48dcec1dea9549 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 22 Jun 2023 08:31:48 +0000 Subject: [PATCH 02/73] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../extractors/neoextractors/__init__.py | 17 ++++++++++++++--- .../extractors/neoextractors/plexon2.py | 5 ++--- .../extractors/tests/test_neoextractors.py | 6 ++++++ 3 files changed, 22 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/extractors/neoextractors/__init__.py b/src/spikeinterface/extractors/neoextractors/__init__.py index 4c12017328..3360b76147 100644 --- a/src/spikeinterface/extractors/neoextractors/__init__.py +++ b/src/spikeinterface/extractors/neoextractors/__init__.py @@ -25,8 +25,14 @@ read_openephys_event, ) from .plexon import PlexonRecordingExtractor, PlexonSortingExtractor, read_plexon, read_plexon_sorting -from .plexon2 import (Plexon2SortingExtractor, Plexon2RecordingExtractor, Plexon2EventExtractor, - read_plexon2, read_plexon2_sorting, read_plexon2_event) +from .plexon2 import ( + Plexon2SortingExtractor, + Plexon2RecordingExtractor, + Plexon2EventExtractor, + read_plexon2, + read_plexon2_sorting, + read_plexon2_event, +) from .spike2 import Spike2RecordingExtractor, read_spike2 from .spikegadgets import SpikeGadgetsRecordingExtractor, read_spikegadgets from .spikeglx import SpikeGLXRecordingExtractor, read_spikeglx @@ -58,6 +64,11 @@ TdtRecordingExtractor, ] -neo_sorting_extractors_list = [BlackrockSortingExtractor, MEArecSortingExtractor, NeuralynxSortingExtractor, Plexon2SortingExtractor] +neo_sorting_extractors_list = [ + BlackrockSortingExtractor, + MEArecSortingExtractor, + NeuralynxSortingExtractor, + Plexon2SortingExtractor, +] neo_event_extractors_list = [AlphaOmegaEventExtractor, OpenEphysBinaryEventExtractor, Plexon2EventExtractor] diff --git a/src/spikeinterface/extractors/neoextractors/plexon2.py b/src/spikeinterface/extractors/neoextractors/plexon2.py index 5ccceac875..c3869dbadc 100644 --- a/src/spikeinterface/extractors/neoextractors/plexon2.py +++ b/src/spikeinterface/extractors/neoextractors/plexon2.py @@ -1,7 +1,6 @@ from spikeinterface.core.core_tools import define_function_from_class -from .neobaseextractor import (NeoBaseRecordingExtractor, NeoBaseSortingExtractor, - NeoBaseEventExtractor) +from .neobaseextractor import NeoBaseRecordingExtractor, NeoBaseSortingExtractor, NeoBaseEventExtractor class Plexon2RecordingExtractor(NeoBaseRecordingExtractor): @@ -99,4 +98,4 @@ def map_to_neo_kwargs(cls, folder_path): read_plexon2 = define_function_from_class(source_class=Plexon2RecordingExtractor, name="read_plexon2") read_plexon2_sorting = define_function_from_class(source_class=Plexon2SortingExtractor, name="read_plexon2_sorting") -read_plexon2_event = define_function_from_class(source_class=Plexon2EventExtractor, name='read_plexon2_event') +read_plexon2_event = define_function_from_class(source_class=Plexon2EventExtractor, name="read_plexon2_event") diff --git a/src/spikeinterface/extractors/tests/test_neoextractors.py b/src/spikeinterface/extractors/tests/test_neoextractors.py index a28752acdd..b14bcc9cf8 100644 --- 
a/src/spikeinterface/extractors/tests/test_neoextractors.py +++ b/src/spikeinterface/extractors/tests/test_neoextractors.py @@ -120,6 +120,7 @@ class NeuroScopeSortingTest(SortingCommonTestSuite, unittest.TestCase): }, ] + class Plexon2EventTest(EventCommonTestSuite, unittest.TestCase): ExtractorClass = Plexon2EventExtractor downloads = ["plexon"] @@ -127,6 +128,7 @@ class Plexon2EventTest(EventCommonTestSuite, unittest.TestCase): ("plexon/4chDemoPL2.pl2"), ] + class PlexonRecordingTest(RecordingCommonTestSuite, unittest.TestCase): ExtractorClass = PlexonRecordingExtractor downloads = ["plexon"] @@ -134,6 +136,7 @@ class PlexonRecordingTest(RecordingCommonTestSuite, unittest.TestCase): "plexon/File_plexon_3.plx", ] + class Plexon2RecordingTest(RecordingCommonTestSuite, unittest.TestCase): ExtractorClass = Plexon2RecordingExtractor downloads = ["plexon"] @@ -141,6 +144,7 @@ class Plexon2RecordingTest(RecordingCommonTestSuite, unittest.TestCase): ("plexon/4chDemoPL2.pl2", {"stream_id": "3"}), ] + class PlexonSortingTest(SortingCommonTestSuite, unittest.TestCase): ExtractorClass = PlexonSortingExtractor downloads = ["plexon"] @@ -148,6 +152,7 @@ class PlexonSortingTest(SortingCommonTestSuite, unittest.TestCase): ("plexon/File_plexon_1.plx"), ] + class Plexon2SortingTest(SortingCommonTestSuite, unittest.TestCase): ExtractorClass = Plexon2SortingExtractor downloads = ["plexon"] @@ -155,6 +160,7 @@ class Plexon2SortingTest(SortingCommonTestSuite, unittest.TestCase): ("plexon/4chDemoPL2.pl2"), ] + class NeuralynxRecordingTest(RecordingCommonTestSuite, unittest.TestCase): ExtractorClass = NeuralynxRecordingExtractor downloads = ["neuralynx"] From a8e9924ac9d14dd7ec5f116112866846eac2e9e2 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 7 Jul 2023 10:42:07 +0200 Subject: [PATCH 03/73] Start waveforme xtarctor in one buffer --- src/spikeinterface/core/waveform_tools.py | 217 +++++++++++++++++++++- 1 file changed, 213 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index a10c209f47..a68f8cfd5f 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -257,8 +257,8 @@ def distribute_waveforms_to_buffers( inds_by_unit[unit_id] = inds # and run - func = _waveform_extractor_chunk - init_func = _init_worker_waveform_extractor + func = _worker_ditribute_buffers + init_func = _init_worker_ditribute_buffers init_args = ( recording, @@ -282,7 +282,7 @@ def distribute_waveforms_to_buffers( # used by ChunkRecordingExecutor -def _init_worker_waveform_extractor( +def _init_worker_ditribute_buffers( recording, unit_ids, spikes, wfs_arrays_info, nbefore, nafter, return_scaled, inds_by_unit, mode, sparsity_mask ): # create a local dict per worker @@ -328,7 +328,216 @@ def _init_worker_waveform_extractor( # used by ChunkRecordingExecutor -def _waveform_extractor_chunk(segment_index, start_frame, end_frame, worker_ctx): +def _worker_ditribute_buffers(segment_index, start_frame, end_frame, worker_ctx): + # recover variables of the worker + recording = worker_ctx["recording"] + unit_ids = worker_ctx["unit_ids"] + spikes = worker_ctx["spikes"] + nbefore = worker_ctx["nbefore"] + nafter = worker_ctx["nafter"] + return_scaled = worker_ctx["return_scaled"] + inds_by_unit = worker_ctx["inds_by_unit"] + sparsity_mask = worker_ctx["sparsity_mask"] + + seg_size = recording.get_num_samples(segment_index=segment_index) + + # take only spikes with the correct segment_index + # this is a slice so no 
copy!! + s0 = np.searchsorted(spikes["segment_index"], segment_index) + s1 = np.searchsorted(spikes["segment_index"], segment_index + 1) + in_seg_spikes = spikes[s0:s1] + + # take only spikes in range [start_frame, end_frame] + # this is a slice so no copy!! + i0 = np.searchsorted(in_seg_spikes["sample_index"], start_frame) + i1 = np.searchsorted(in_seg_spikes["sample_index"], end_frame) + if i0 != i1: + # protect from spikes on border : spike_time<0 or spike_time>seg_size + # useful only when max_spikes_per_unit is not None + # waveform will not be extracted and a zeros will be left in the memmap file + while (in_seg_spikes[i0]["sample_index"] - nbefore) < 0 and (i0 != i1): + i0 = i0 + 1 + while (in_seg_spikes[i1 - 1]["sample_index"] + nafter) > seg_size and (i0 != i1): + i1 = i1 - 1 + + # slice in absolut in spikes vector + l0 = i0 + s0 + l1 = i1 + s0 + + if l1 > l0: + start = spikes[l0]["sample_index"] - nbefore + end = spikes[l1 - 1]["sample_index"] + nafter + + # load trace in memory + traces = recording.get_traces( + start_frame=start, end_frame=end, segment_index=segment_index, return_scaled=return_scaled + ) + + for unit_ind, unit_id in enumerate(unit_ids): + # find pos + inds = inds_by_unit[unit_id] + (in_chunk_pos,) = np.nonzero((inds >= l0) & (inds < l1)) + if in_chunk_pos.size == 0: + continue + + if worker_ctx["mode"] == "memmap": + # open file in demand (and also autoclose it after) + filename = worker_ctx["wfs_arrays_info"][unit_id] + wfs = np.load(str(filename), mmap_mode="r+") + elif worker_ctx["mode"] == "shared_memory": + wfs = worker_ctx["wfs_arrays"][unit_id] + + for pos in in_chunk_pos: + sample_index = spikes[inds[pos]]["sample_index"] + wf = traces[sample_index - start - nbefore : sample_index - start + nafter, :] + + if sparsity_mask is None: + wfs[pos, :, :] = wf + else: + wfs[pos, :, :] = wf[:, sparsity_mask[unit_ind]] + + +def extract_waveforms_to_unique_buffer( + recording, + spikes, + unit_ids, + nbefore, + nafter, + mode="memmap", + return_scaled=False, + folder=None, + dtype=None, + sparsity_mask=None, + copy=False, + **job_kwargs, +): + + nsamples = nbefore + nafter + + dtype = np.dtype(dtype) + if mode == "shared_memory": + assert folder is None + else: + folder = Path(folder) + + num_spikes = spike.size + if sparsity_mask is None: + num_chans = recording.get_num_channels() + else: + num_chans = np.sum(sparsity_mask[unit_ind, :]) + shape = (num_spikes, nsamples, num_chans) + + if mode == "memmap": + filename = str(folder / f"all_waveforms.npy") + wfs_array = np.lib.format.open_memmap(filename, mode="w+", dtype=dtype, shape=shape) + wf_array_info = filename + elif mode == "shared_memory": + if num_spikes == 0: + wfs_array = np.zeros(shape, dtype=dtype) + shm = None + shm_name = None + else: + wfs_array, shm = make_shared_array(shape, dtype) + shm_name = shm.name + wf_array_info = (shm, shm_name, dtype.str, shape) + else: + raise ValueError("allocate_waveforms_buffers bad mode") + + + job_kwargs = fix_job_kwargs(job_kwargs) + + inds_by_unit = {} + for unit_ind, unit_id in enumerate(unit_ids): + (inds,) = np.nonzero(spikes["unit_index"] == unit_ind) + inds_by_unit[unit_id] = inds + + if num_spikes > 0: + # and run + func = _worker_ditribute_one_buffer + init_func = _init_worker_ditribute_buffers + + init_args = ( + recording, + unit_ids, + spikes, + wf_array_info, + nbefore, + nafter, + return_scaled, + mode, + sparsity_mask, + ) + processor = ChunkRecordingExecutor( + recording, func, init_func, init_args, job_name=f"extract waveforms {mode}", **job_kwargs + 
) + processor.run() + + + # if mode == "memmap": + # return wfs_arrays + # elif mode == "shared_memory": + # if copy: + # wfs_arrays = {unit_id: arr.copy() for unit_id, arr in wfs_arrays.items()} + # # release all sharedmem buffer + # for unit_id in unit_ids: + # shm = wfs_arrays_info[unit_id][0] + # if shm is not None: + # # empty array have None + # shm.unlink() + # return wfs_arrays + # else: + # return wfs_arrays, wfs_arrays_info + + + + +def _init_worker_ditribute_one_buffer( + recording, unit_ids, spikes, wf_array_info, nbefore, nafter, return_scaled, mode, sparsity_mask +): + + worker_ctx = {} + worker_ctx["recording"] = recording + worker_ctx["wf_array_info"] = wf_array_info + + if mode == "memmap": + filename = wf_array_info + wfs = np.load(str(filename), mmap_mode="r+") + + # in memmap mode we have the "too many open file" problem with linux + # memmap file will be open on demand and not globally per worker + worker_ctx["wf_array_info"] = wf_array_info + elif mode == "shared_memory": + from multiprocessing.shared_memory import SharedMemory + + wfs_arrays = {} + shms = {} + for unit_id, (shm, shm_name, dtype, shape) in wfs_arrays_info.items(): + if shm_name is None: + arr = np.zeros(shape=shape, dtype=dtype) + else: + shm = SharedMemory(shm_name) + arr = np.ndarray(shape=shape, dtype=dtype, buffer=shm.buf) + wfs_arrays[unit_id] = arr + # we need a reference to all sham otherwise we get segment fault!!! + shms[unit_id] = shm + worker_ctx["shms"] = shms + worker_ctx["wfs_arrays"] = wfs_arrays + + worker_ctx["unit_ids"] = unit_ids + worker_ctx["spikes"] = spikes + + worker_ctx["nbefore"] = nbefore + worker_ctx["nafter"] = nafter + worker_ctx["return_scaled"] = return_scaled + worker_ctx["inds_by_unit"] = inds_by_unit + worker_ctx["sparsity_mask"] = sparsity_mask + worker_ctx["mode"] = mode + + return worker_ctx + + +# used by ChunkRecordingExecutor +def _worker_ditribute_one_buffer(segment_index, start_frame, end_frame, worker_ctx): # recover variables of the worker recording = worker_ctx["recording"] unit_ids = worker_ctx["unit_ids"] From bb4457505e22fb5074cf17060391e4e5ce91a80f Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 7 Jul 2023 14:37:15 +0200 Subject: [PATCH 04/73] wip waveform tools speedup --- .../core/tests/test_waveform_tools.py | 29 +++- src/spikeinterface/core/waveform_tools.py | 124 ++++++++---------- 2 files changed, 77 insertions(+), 76 deletions(-) diff --git a/src/spikeinterface/core/tests/test_waveform_tools.py b/src/spikeinterface/core/tests/test_waveform_tools.py index a896ff9c8b..457be5cba4 100644 --- a/src/spikeinterface/core/tests/test_waveform_tools.py +++ b/src/spikeinterface/core/tests/test_waveform_tools.py @@ -7,7 +7,7 @@ from spikeinterface.core import generate_recording, generate_sorting from spikeinterface.core.waveform_tools import ( - extract_waveforms_to_buffers, + extract_waveforms_to_buffers, extract_waveforms_to_unique_buffer, ) # allocate_waveforms_buffers, distribute_waveforms_to_buffers @@ -64,8 +64,7 @@ def test_waveform_tools(): if wf_folder.is_dir(): shutil.rmtree(wf_folder) wf_folder.mkdir() - # wfs_arrays, wfs_arrays_info = allocate_waveforms_buffers(recording, spikes, unit_ids, nbefore, nafter, mode='memmap', folder=wf_folder, dtype=dtype) - # distribute_waveforms_to_buffers(recording, spikes, unit_ids, wfs_arrays_info, nbefore, nafter, return_scaled, **job_kwargs) + wfs_arrays = extract_waveforms_to_buffers( recording, spikes, @@ -84,8 +83,27 @@ def test_waveform_tools(): wf = wfs_arrays[unit_id] assert wf.shape[0] == 
np.sum(spikes["unit_index"] == unit_ind) list_wfs.append({unit_id: wfs_arrays[unit_id].copy() for unit_id in unit_ids}) + + wfs_array = extract_waveforms_to_unique_buffer( + recording, + spikes, + unit_ids, + nbefore, + nafter, + mode="memmap", + return_scaled=False, + folder=wf_folder, + dtype=dtype, + sparsity_mask=None, + copy=False, + **job_kwargs, + ) + print(wfs_array.shape) + _check_all_wf_equal(list_wfs) + + # memory if platform.system() != "Windows": # shared memory on windows is buggy... @@ -125,9 +143,6 @@ def test_waveform_tools(): sparsity_mask = np.random.randint(0, 2, size=(unit_ids.size, recording.channel_ids.size), dtype="bool") job_kwargs = {"n_jobs": 1, "chunk_size": 3000, "progress_bar": True} - # wfs_arrays, wfs_arrays_info = allocate_waveforms_buffers(recording, spikes, unit_ids, nbefore, nafter, mode='memmap', folder=wf_folder, dtype=dtype, sparsity_mask=sparsity_mask) - # distribute_waveforms_to_buffers(recording, spikes, unit_ids, wfs_arrays_info, nbefore, nafter, return_scaled, sparsity_mask=sparsity_mask, **job_kwargs) - wfs_arrays = extract_waveforms_to_buffers( recording, spikes, @@ -144,5 +159,7 @@ def test_waveform_tools(): ) + + if __name__ == "__main__": test_waveform_tools() diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index a68f8cfd5f..a7c8493381 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -214,6 +214,7 @@ def distribute_waveforms_to_buffers( return_scaled, mode="memmap", sparsity_mask=None, + job_name=None, **job_kwargs, ): """ @@ -272,9 +273,9 @@ def distribute_waveforms_to_buffers( mode, sparsity_mask, ) - processor = ChunkRecordingExecutor( - recording, func, init_func, init_args, job_name=f"extract waveforms {mode}", **job_kwargs - ) + if job_name is None: + job_name=f"extract waveforms {mode} multi buffer" + processor = ChunkRecordingExecutor(recording, func, init_func, init_args, job_name=job_name, **job_kwargs) processor.run() @@ -409,6 +410,7 @@ def extract_waveforms_to_unique_buffer( dtype=None, sparsity_mask=None, copy=False, + job_name=None, **job_kwargs, ): @@ -420,11 +422,11 @@ def extract_waveforms_to_unique_buffer( else: folder = Path(folder) - num_spikes = spike.size + num_spikes = spikes.size if sparsity_mask is None: num_chans = recording.get_num_channels() else: - num_chans = np.sum(sparsity_mask[unit_ind, :]) + num_chans = max(np.sum(sparsity_mask, axis=1)) shape = (num_spikes, nsamples, num_chans) if mode == "memmap": @@ -454,7 +456,7 @@ def extract_waveforms_to_unique_buffer( if num_spikes > 0: # and run func = _worker_ditribute_one_buffer - init_func = _init_worker_ditribute_buffers + init_func = _init_worker_ditribute_one_buffer init_args = ( recording, @@ -466,27 +468,27 @@ def extract_waveforms_to_unique_buffer( return_scaled, mode, sparsity_mask, + ) - processor = ChunkRecordingExecutor( - recording, func, init_func, init_args, job_name=f"extract waveforms {mode}", **job_kwargs - ) + if job_name is None: + job_name = f"extract waveforms {mode} mono buffer" + + processor = ChunkRecordingExecutor(recording, func, init_func, init_args, job_name=job_name, **job_kwargs) processor.run() - # if mode == "memmap": - # return wfs_arrays - # elif mode == "shared_memory": - # if copy: - # wfs_arrays = {unit_id: arr.copy() for unit_id, arr in wfs_arrays.items()} - # # release all sharedmem buffer - # for unit_id in unit_ids: - # shm = wfs_arrays_info[unit_id][0] - # if shm is not None: - # # empty array have None - # 
shm.unlink() - # return wfs_arrays - # else: - # return wfs_arrays, wfs_arrays_info + if mode == "memmap": + return wfs_array + elif mode == "shared_memory": + if copy: + wf_array_info = wf_array_info.copy() + if shm is not None: + # release all sharedmem buffer + # empty array have None + shm.unlink() + return wfs_array + else: + return wfs_array, wf_array_info @@ -501,35 +503,29 @@ def _init_worker_ditribute_one_buffer( if mode == "memmap": filename = wf_array_info - wfs = np.load(str(filename), mmap_mode="r+") - - # in memmap mode we have the "too many open file" problem with linux - # memmap file will be open on demand and not globally per worker - worker_ctx["wf_array_info"] = wf_array_info + wfs_array = np.load(str(filename), mmap_mode="r+") + worker_ctx["wfs_array"] = wfs_array elif mode == "shared_memory": from multiprocessing.shared_memory import SharedMemory - - wfs_arrays = {} - shms = {} - for unit_id, (shm, shm_name, dtype, shape) in wfs_arrays_info.items(): - if shm_name is None: - arr = np.zeros(shape=shape, dtype=dtype) - else: - shm = SharedMemory(shm_name) - arr = np.ndarray(shape=shape, dtype=dtype, buffer=shm.buf) - wfs_arrays[unit_id] = arr - # we need a reference to all sham otherwise we get segment fault!!! - shms[unit_id] = shm - worker_ctx["shms"] = shms - worker_ctx["wfs_arrays"] = wfs_arrays + shm, shm_name, dtype, shape = wf_array_info + shm = SharedMemory(shm_name) + wfs_array = np.ndarray(shape=shape, dtype=dtype, buffer=shm.buf) + worker_ctx["shm"] = shm + worker_ctx["wfs_array"] = wfs_array + + # prepare segment slices + segment_slices = [] + for segment_index in range(recording.get_num_segments()): + s0 = np.searchsorted(spikes["segment_index"], segment_index) + s1 = np.searchsorted(spikes["segment_index"], segment_index + 1) + segment_slices.append((s0, s1)) + worker_ctx["segment_slices"] = segment_slices worker_ctx["unit_ids"] = unit_ids worker_ctx["spikes"] = spikes - worker_ctx["nbefore"] = nbefore worker_ctx["nafter"] = nafter worker_ctx["return_scaled"] = return_scaled - worker_ctx["inds_by_unit"] = inds_by_unit worker_ctx["sparsity_mask"] = sparsity_mask worker_ctx["mode"] = mode @@ -541,20 +537,18 @@ def _worker_ditribute_one_buffer(segment_index, start_frame, end_frame, worker_c # recover variables of the worker recording = worker_ctx["recording"] unit_ids = worker_ctx["unit_ids"] + segment_slices = worker_ctx["segment_slices"] spikes = worker_ctx["spikes"] nbefore = worker_ctx["nbefore"] nafter = worker_ctx["nafter"] return_scaled = worker_ctx["return_scaled"] - inds_by_unit = worker_ctx["inds_by_unit"] sparsity_mask = worker_ctx["sparsity_mask"] + wfs_array = worker_ctx["wfs_array"] seg_size = recording.get_num_samples(segment_index=segment_index) - # take only spikes with the correct segment_index - # this is a slice so no copy!! - s0 = np.searchsorted(spikes["segment_index"], segment_index) - s1 = np.searchsorted(spikes["segment_index"], segment_index + 1) - in_seg_spikes = spikes[s0:s1] + s0, s1 = segment_slices[segment_index] + in_seg_spikes = spikes[s0: s1] # take only spikes in range [start_frame, end_frame] # this is a slice so no copy!! 
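Note: the segment_slices bookkeeping added in this hunk relies on the spike vector being sorted by segment_index (and by sample_index within each segment), so np.searchsorted can return contiguous [s0, s1) ranges and spikes[s0:s1] stays a view rather than a copy. A small self-contained sketch of that slicing technique, using a toy spike vector whose fields mirror the ones used in this diff (the values are illustrative only):

    import numpy as np

    # toy spike vector, already sorted by (segment_index, sample_index),
    # mirroring the fields of Sorting.to_spike_vector()
    spike_dtype = [("sample_index", "int64"), ("unit_index", "int64"), ("segment_index", "int64")]
    spikes = np.zeros(5, dtype=spike_dtype)
    spikes["segment_index"] = [0, 0, 0, 1, 1]
    spikes["sample_index"] = [10, 50, 90, 20, 60]

    segment_slices = []
    for segment_index in range(2):
        # searchsorted on the sorted segment_index column gives the [s0, s1) range
        # of spikes belonging to this segment; spikes[s0:s1] is a view, not a copy
        s0 = np.searchsorted(spikes["segment_index"], segment_index)
        s1 = np.searchsorted(spikes["segment_index"], segment_index + 1)
        segment_slices.append((int(s0), int(s1)))

    print(segment_slices)  # [(0, 3), (3, 5)]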
@@ -582,28 +576,18 @@ def _worker_ditribute_one_buffer(segment_index, start_frame, end_frame, worker_c start_frame=start, end_frame=end, segment_index=segment_index, return_scaled=return_scaled ) - for unit_ind, unit_id in enumerate(unit_ids): - # find pos - inds = inds_by_unit[unit_id] - (in_chunk_pos,) = np.nonzero((inds >= l0) & (inds < l1)) - if in_chunk_pos.size == 0: - continue + for spike_ind in range(l0, l1): + sample_index = spikes[spike_ind]["sample_index"] + unit_index = spikes[spike_ind]["unit_index"] + wf = traces[sample_index - start - nbefore : sample_index - start + nafter, :] - if worker_ctx["mode"] == "memmap": - # open file in demand (and also autoclose it after) - filename = worker_ctx["wfs_arrays_info"][unit_id] - wfs = np.load(str(filename), mmap_mode="r+") - elif worker_ctx["mode"] == "shared_memory": - wfs = worker_ctx["wfs_arrays"][unit_id] - - for pos in in_chunk_pos: - sample_index = spikes[inds[pos]]["sample_index"] - wf = traces[sample_index - start - nbefore : sample_index - start + nafter, :] + if sparsity_mask is None: + wfs_array[spike_ind, :, :] = None + else: + mask = sparsity_mask[unit_index, :] + wf = wf[:, mask] + wfs_array[spike_ind, :, :wf.shape[1]] = wf - if sparsity_mask is None: - wfs[pos, :, :] = wf - else: - wfs[pos, :, :] = wf[:, sparsity_mask[unit_ind]] def has_exceeding_spikes(recording, sorting): From 5539bdb0028b2bdcebb6b4e6d201f0579770f3e1 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 7 Jul 2023 17:02:02 +0200 Subject: [PATCH 05/73] extract_waveforms_to_unique_buffer is more or less OK --- .../core/tests/test_waveform_tools.py | 181 ++++++++---------- src/spikeinterface/core/waveform_tools.py | 59 +++--- 2 files changed, 115 insertions(+), 125 deletions(-) diff --git a/src/spikeinterface/core/tests/test_waveform_tools.py b/src/spikeinterface/core/tests/test_waveform_tools.py index 457be5cba4..ef75180898 100644 --- a/src/spikeinterface/core/tests/test_waveform_tools.py +++ b/src/spikeinterface/core/tests/test_waveform_tools.py @@ -7,8 +7,8 @@ from spikeinterface.core import generate_recording, generate_sorting from spikeinterface.core.waveform_tools import ( - extract_waveforms_to_buffers, extract_waveforms_to_unique_buffer, -) # allocate_waveforms_buffers, distribute_waveforms_to_buffers + extract_waveforms_to_buffers, extract_waveforms_to_unique_buffer, split_waveforms_by_units, +) if hasattr(pytest, "global_test_folder"): @@ -21,6 +21,10 @@ def _check_all_wf_equal(list_wfs_arrays): wfs_arrays0 = list_wfs_arrays[0] for i, wfs_arrays in enumerate(list_wfs_arrays): for unit_id in wfs_arrays.keys(): + print() + print('*'*10) + print(wfs_arrays[unit_id].shape) + print(wfs_arrays0[unit_id].shape) assert np.array_equal(wfs_arrays[unit_id], wfs_arrays0[unit_id]) @@ -52,111 +56,86 @@ def test_waveform_tools(): unit_ids = sorting.unit_ids some_job_kwargs = [ - {}, {"n_jobs": 1, "chunk_size": 3000, "progress_bar": True}, {"n_jobs": 2, "chunk_size": 3000, "progress_bar": True}, ] + some_modes = [ + {"mode" : "memmap"}, + ] + if platform.system() != "Windows": + # shared memory on windows is buggy... 
+ some_modes.append({"mode" : "shared_memory", }) + + some_sparsity = [ + dict(sparsity_mask=None), + dict(sparsity_mask = np.random.randint(0, 2, size=(unit_ids.size, recording.channel_ids.size), dtype="bool")), + ] + # memmap mode - list_wfs = [] + list_wfs_dense = [] + list_wfs_sparse = [] for j, job_kwargs in enumerate(some_job_kwargs): - wf_folder = cache_folder / f"test_waveform_tools_{j}" - if wf_folder.is_dir(): - shutil.rmtree(wf_folder) - wf_folder.mkdir() - - wfs_arrays = extract_waveforms_to_buffers( - recording, - spikes, - unit_ids, - nbefore, - nafter, - mode="memmap", - return_scaled=False, - folder=wf_folder, - dtype=dtype, - sparsity_mask=None, - copy=False, - **job_kwargs, - ) - for unit_ind, unit_id in enumerate(unit_ids): - wf = wfs_arrays[unit_id] - assert wf.shape[0] == np.sum(spikes["unit_index"] == unit_ind) - list_wfs.append({unit_id: wfs_arrays[unit_id].copy() for unit_id in unit_ids}) - - wfs_array = extract_waveforms_to_unique_buffer( - recording, - spikes, - unit_ids, - nbefore, - nafter, - mode="memmap", - return_scaled=False, - folder=wf_folder, - dtype=dtype, - sparsity_mask=None, - copy=False, - **job_kwargs, - ) - print(wfs_array.shape) - - _check_all_wf_equal(list_wfs) - - - - # memory - if platform.system() != "Windows": - # shared memory on windows is buggy... - list_wfs = [] - for job_kwargs in some_job_kwargs: - # wfs_arrays, wfs_arrays_info = allocate_waveforms_buffers(recording, spikes, unit_ids, nbefore, nafter, mode='shared_memory', folder=None, dtype=dtype) - # distribute_waveforms_to_buffers(recording, spikes, unit_ids, wfs_arrays_info, nbefore, nafter, return_scaled, mode='shared_memory', **job_kwargs) - wfs_arrays = extract_waveforms_to_buffers( - recording, - spikes, - unit_ids, - nbefore, - nafter, - mode="shared_memory", - return_scaled=False, - folder=None, - dtype=dtype, - sparsity_mask=None, - copy=True, - **job_kwargs, - ) - for unit_ind, unit_id in enumerate(unit_ids): - wf = wfs_arrays[unit_id] - assert wf.shape[0] == np.sum(spikes["unit_index"] == unit_ind) - list_wfs.append({unit_id: wfs_arrays[unit_id].copy() for unit_id in unit_ids}) - # to avoid warning we need to first destroy arrays then sharedmemm object - # del wfs_arrays - # del wfs_arrays_info - _check_all_wf_equal(list_wfs) - - # with sparsity - wf_folder = cache_folder / "test_waveform_tools_sparse" - if wf_folder.is_dir(): - shutil.rmtree(wf_folder) - wf_folder.mkdir() - - sparsity_mask = np.random.randint(0, 2, size=(unit_ids.size, recording.channel_ids.size), dtype="bool") - job_kwargs = {"n_jobs": 1, "chunk_size": 3000, "progress_bar": True} - - wfs_arrays = extract_waveforms_to_buffers( - recording, - spikes, - unit_ids, - nbefore, - nafter, - mode="memmap", - return_scaled=False, - folder=wf_folder, - dtype=dtype, - sparsity_mask=sparsity_mask, - copy=False, - **job_kwargs, - ) + for k, mode_kwargs in enumerate(some_modes): + for l, sparsity_kwargs in enumerate(some_sparsity): + + print() + print(job_kwargs, mode_kwargs, 'sparse=', sparsity_kwargs['sparsity_mask'] is None) + + if mode_kwargs["mode"] == "memmap": + wf_folder = cache_folder / f"test_waveform_tools_{j}_{k}_{l}" + if wf_folder.is_dir(): + shutil.rmtree(wf_folder) + wf_folder.mkdir() + mode_kwargs_ = dict(**mode_kwargs, folder=wf_folder) + else: + mode_kwargs_ = mode_kwargs + + wfs_arrays = extract_waveforms_to_buffers( + recording, + spikes, + unit_ids, + nbefore, + nafter, + return_scaled=False, + dtype=dtype, + copy=True, + **sparsity_kwargs, + **mode_kwargs_, + **job_kwargs, + ) + for unit_ind, 
unit_id in enumerate(unit_ids): + wf = wfs_arrays[unit_id] + assert wf.shape[0] == np.sum(spikes["unit_index"] == unit_ind) + + if sparsity_kwargs['sparsity_mask'] is None: + list_wfs_dense.append(wfs_arrays) + else: + list_wfs_sparse.append(wfs_arrays) + + + all_waveforms = extract_waveforms_to_unique_buffer( + recording, + spikes, + unit_ids, + nbefore, + nafter, + return_scaled=False, + dtype=dtype, + copy=True, + **sparsity_kwargs, + **mode_kwargs_, + **job_kwargs, + ) + wfs_arrays = split_waveforms_by_units(unit_ids, spikes, all_waveforms, sparsity_mask=sparsity_kwargs['sparsity_mask']) + if sparsity_kwargs['sparsity_mask'] is None: + list_wfs_dense.append(wfs_arrays) + else: + list_wfs_sparse.append(wfs_arrays) + + _check_all_wf_equal(list_wfs_dense) + _check_all_wf_equal(list_wfs_sparse) + diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index a7c8493381..1a2361ec97 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -189,7 +189,7 @@ def allocate_waveforms_buffers( wfs_arrays[unit_id] = arr wfs_arrays_info[unit_id] = filename elif mode == "shared_memory": - if n_spikes == 0: + if n_spikes == 0 or num_chans == 0: arr = np.zeros(shape, dtype=dtype) shm = None shm_name = None @@ -431,15 +431,15 @@ def extract_waveforms_to_unique_buffer( if mode == "memmap": filename = str(folder / f"all_waveforms.npy") - wfs_array = np.lib.format.open_memmap(filename, mode="w+", dtype=dtype, shape=shape) + all_waveforms = np.lib.format.open_memmap(filename, mode="w+", dtype=dtype, shape=shape) wf_array_info = filename elif mode == "shared_memory": - if num_spikes == 0: - wfs_array = np.zeros(shape, dtype=dtype) + if num_spikes == 0 or num_chans == 0: + all_waveforms = np.zeros(shape, dtype=dtype) shm = None shm_name = None else: - wfs_array, shm = make_shared_array(shape, dtype) + all_waveforms, shm = make_shared_array(shape, dtype) shm_name = shm.name wf_array_info = (shm, shm_name, dtype.str, shape) else: @@ -478,17 +478,16 @@ def extract_waveforms_to_unique_buffer( if mode == "memmap": - return wfs_array + return all_waveforms elif mode == "shared_memory": if copy: - wf_array_info = wf_array_info.copy() if shm is not None: # release all sharedmem buffer # empty array have None shm.unlink() - return wfs_array + return all_waveforms.copy() else: - return wfs_array, wf_array_info + return all_waveforms, wf_array_info @@ -500,18 +499,25 @@ def _init_worker_ditribute_one_buffer( worker_ctx = {} worker_ctx["recording"] = recording worker_ctx["wf_array_info"] = wf_array_info + worker_ctx["unit_ids"] = unit_ids + worker_ctx["spikes"] = spikes + worker_ctx["nbefore"] = nbefore + worker_ctx["nafter"] = nafter + worker_ctx["return_scaled"] = return_scaled + worker_ctx["sparsity_mask"] = sparsity_mask + worker_ctx["mode"] = mode if mode == "memmap": filename = wf_array_info - wfs_array = np.load(str(filename), mmap_mode="r+") - worker_ctx["wfs_array"] = wfs_array + all_waveforms = np.load(str(filename), mmap_mode="r+") + worker_ctx["all_waveforms"] = all_waveforms elif mode == "shared_memory": from multiprocessing.shared_memory import SharedMemory shm, shm_name, dtype, shape = wf_array_info shm = SharedMemory(shm_name) - wfs_array = np.ndarray(shape=shape, dtype=dtype, buffer=shm.buf) + all_waveforms = np.ndarray(shape=shape, dtype=dtype, buffer=shm.buf) worker_ctx["shm"] = shm - worker_ctx["wfs_array"] = wfs_array + worker_ctx["all_waveforms"] = all_waveforms # prepare segment slices segment_slices = [] @@ 
-521,14 +527,6 @@ def _init_worker_ditribute_one_buffer( segment_slices.append((s0, s1)) worker_ctx["segment_slices"] = segment_slices - worker_ctx["unit_ids"] = unit_ids - worker_ctx["spikes"] = spikes - worker_ctx["nbefore"] = nbefore - worker_ctx["nafter"] = nafter - worker_ctx["return_scaled"] = return_scaled - worker_ctx["sparsity_mask"] = sparsity_mask - worker_ctx["mode"] = mode - return worker_ctx @@ -543,7 +541,7 @@ def _worker_ditribute_one_buffer(segment_index, start_frame, end_frame, worker_c nafter = worker_ctx["nafter"] return_scaled = worker_ctx["return_scaled"] sparsity_mask = worker_ctx["sparsity_mask"] - wfs_array = worker_ctx["wfs_array"] + all_waveforms = worker_ctx["all_waveforms"] seg_size = recording.get_num_samples(segment_index=segment_index) @@ -582,12 +580,25 @@ def _worker_ditribute_one_buffer(segment_index, start_frame, end_frame, worker_c wf = traces[sample_index - start - nbefore : sample_index - start + nafter, :] if sparsity_mask is None: - wfs_array[spike_ind, :, :] = None + all_waveforms[spike_ind, :, :] = wf else: mask = sparsity_mask[unit_index, :] wf = wf[:, mask] - wfs_array[spike_ind, :, :wf.shape[1]] = wf + all_waveforms[spike_ind, :, :wf.shape[1]] = wf + + +def split_waveforms_by_units(unit_ids, spikes, all_waveforms, sparsity_mask=None): + waveform_by_units = {} + for unit_index, unit_id in enumerate(unit_ids): + mask = spikes["unit_index"] == unit_index + if sparsity_mask is not None: + chan_mask = sparsity_mask[unit_index, :] + num_chans = np.sum(chan_mask) + waveform_by_units[unit_id] = all_waveforms[mask, :, :][:, :, :num_chans] + else: + waveform_by_units[unit_id] = all_waveforms[mask, :, :] + return waveform_by_units def has_exceeding_spikes(recording, sorting): From ec86e2987414dc060ff657531dd3be3b475bf49b Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 7 Jul 2023 17:02:34 +0200 Subject: [PATCH 06/73] clean --- src/spikeinterface/core/tests/test_waveform_tools.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/core/tests/test_waveform_tools.py b/src/spikeinterface/core/tests/test_waveform_tools.py index ef75180898..c86aa6d5d7 100644 --- a/src/spikeinterface/core/tests/test_waveform_tools.py +++ b/src/spikeinterface/core/tests/test_waveform_tools.py @@ -21,10 +21,6 @@ def _check_all_wf_equal(list_wfs_arrays): wfs_arrays0 = list_wfs_arrays[0] for i, wfs_arrays in enumerate(list_wfs_arrays): for unit_id in wfs_arrays.keys(): - print() - print('*'*10) - print(wfs_arrays[unit_id].shape) - print(wfs_arrays0[unit_id].shape) assert np.array_equal(wfs_arrays[unit_id], wfs_arrays0[unit_id]) @@ -79,8 +75,8 @@ def test_waveform_tools(): for k, mode_kwargs in enumerate(some_modes): for l, sparsity_kwargs in enumerate(some_sparsity): - print() - print(job_kwargs, mode_kwargs, 'sparse=', sparsity_kwargs['sparsity_mask'] is None) + # print() + # print(job_kwargs, mode_kwargs, 'sparse=', sparsity_kwargs['sparsity_mask'] is None) if mode_kwargs["mode"] == "memmap": wf_folder = cache_folder / f"test_waveform_tools_{j}_{k}_{l}" From da729088eb539991dc315c66a634ac23cc0f136f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 7 Jul 2023 15:03:10 +0000 Subject: [PATCH 07/73] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../core/tests/test_waveform_tools.py | 32 ++++++++++--------- src/spikeinterface/core/waveform_tools.py | 14 +++----- 2 files changed, 21 insertions(+), 25 deletions(-) 
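As a rough orientation for the tests reworked below, this is a usage sketch of the single-buffer path introduced in this part of the series (extract_waveforms_to_unique_buffer, renamed to extract_waveforms_to_single_buffer in a later commit): all waveforms land in one (num_spikes, num_samples, num_channels) buffer, and per-unit arrays are recovered afterwards from the spike vector. The toy data and keyword values below are illustrative, not taken from the patch:

    from spikeinterface.core import generate_recording, generate_sorting
    from spikeinterface.core.waveform_tools import (
        extract_waveforms_to_unique_buffer,
        split_waveforms_by_units,
    )

    recording = generate_recording(num_channels=2, durations=[10.0], sampling_frequency=30000.0)
    sorting = generate_sorting(num_units=3, durations=[10.0], sampling_frequency=30000.0)
    spikes = sorting.to_spike_vector()
    unit_ids = sorting.unit_ids
    nbefore, nafter = 20, 40

    # one dense buffer for all spikes of all units
    all_waveforms = extract_waveforms_to_unique_buffer(
        recording, spikes, unit_ids, nbefore, nafter,
        mode="shared_memory", return_scaled=False, folder=None,
        dtype="float32", sparsity_mask=None, copy=True,
        n_jobs=1, chunk_size=3000,
    )
    print(all_waveforms.shape)  # (num_spikes, nbefore + nafter, num_channels)

    # recover a dict of per-unit arrays from the single buffer and the spike vector
    wfs_by_unit = split_waveforms_by_units(unit_ids, spikes, all_waveforms, sparsity_mask=None)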
diff --git a/src/spikeinterface/core/tests/test_waveform_tools.py b/src/spikeinterface/core/tests/test_waveform_tools.py index c86aa6d5d7..9a51a10ee2 100644 --- a/src/spikeinterface/core/tests/test_waveform_tools.py +++ b/src/spikeinterface/core/tests/test_waveform_tools.py @@ -7,7 +7,9 @@ from spikeinterface.core import generate_recording, generate_sorting from spikeinterface.core.waveform_tools import ( - extract_waveforms_to_buffers, extract_waveforms_to_unique_buffer, split_waveforms_by_units, + extract_waveforms_to_buffers, + extract_waveforms_to_unique_buffer, + split_waveforms_by_units, ) @@ -56,17 +58,20 @@ def test_waveform_tools(): {"n_jobs": 2, "chunk_size": 3000, "progress_bar": True}, ] some_modes = [ - {"mode" : "memmap"}, + {"mode": "memmap"}, ] if platform.system() != "Windows": # shared memory on windows is buggy... - some_modes.append({"mode" : "shared_memory", }) + some_modes.append( + { + "mode": "shared_memory", + } + ) some_sparsity = [ dict(sparsity_mask=None), - dict(sparsity_mask = np.random.randint(0, 2, size=(unit_ids.size, recording.channel_ids.size), dtype="bool")), + dict(sparsity_mask=np.random.randint(0, 2, size=(unit_ids.size, recording.channel_ids.size), dtype="bool")), ] - # memmap mode list_wfs_dense = [] @@ -74,7 +79,6 @@ def test_waveform_tools(): for j, job_kwargs in enumerate(some_job_kwargs): for k, mode_kwargs in enumerate(some_modes): for l, sparsity_kwargs in enumerate(some_sparsity): - # print() # print(job_kwargs, mode_kwargs, 'sparse=', sparsity_kwargs['sparsity_mask'] is None) @@ -86,7 +90,7 @@ def test_waveform_tools(): mode_kwargs_ = dict(**mode_kwargs, folder=wf_folder) else: mode_kwargs_ = mode_kwargs - + wfs_arrays = extract_waveforms_to_buffers( recording, spikes, @@ -103,13 +107,12 @@ def test_waveform_tools(): for unit_ind, unit_id in enumerate(unit_ids): wf = wfs_arrays[unit_id] assert wf.shape[0] == np.sum(spikes["unit_index"] == unit_ind) - - if sparsity_kwargs['sparsity_mask'] is None: + + if sparsity_kwargs["sparsity_mask"] is None: list_wfs_dense.append(wfs_arrays) else: list_wfs_sparse.append(wfs_arrays) - all_waveforms = extract_waveforms_to_unique_buffer( recording, spikes, @@ -123,8 +126,10 @@ def test_waveform_tools(): **mode_kwargs_, **job_kwargs, ) - wfs_arrays = split_waveforms_by_units(unit_ids, spikes, all_waveforms, sparsity_mask=sparsity_kwargs['sparsity_mask']) - if sparsity_kwargs['sparsity_mask'] is None: + wfs_arrays = split_waveforms_by_units( + unit_ids, spikes, all_waveforms, sparsity_mask=sparsity_kwargs["sparsity_mask"] + ) + if sparsity_kwargs["sparsity_mask"] is None: list_wfs_dense.append(wfs_arrays) else: list_wfs_sparse.append(wfs_arrays) @@ -133,8 +138,5 @@ def test_waveform_tools(): _check_all_wf_equal(list_wfs_sparse) - - - if __name__ == "__main__": test_waveform_tools() diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index 1a2361ec97..e6f7e944cc 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -274,7 +274,7 @@ def distribute_waveforms_to_buffers( sparsity_mask, ) if job_name is None: - job_name=f"extract waveforms {mode} multi buffer" + job_name = f"extract waveforms {mode} multi buffer" processor = ChunkRecordingExecutor(recording, func, init_func, init_args, job_name=job_name, **job_kwargs) processor.run() @@ -413,7 +413,6 @@ def extract_waveforms_to_unique_buffer( job_name=None, **job_kwargs, ): - nsamples = nbefore + nafter dtype = np.dtype(dtype) @@ -445,7 +444,6 @@ def 
extract_waveforms_to_unique_buffer( else: raise ValueError("allocate_waveforms_buffers bad mode") - job_kwargs = fix_job_kwargs(job_kwargs) inds_by_unit = {} @@ -468,7 +466,6 @@ def extract_waveforms_to_unique_buffer( return_scaled, mode, sparsity_mask, - ) if job_name is None: job_name = f"extract waveforms {mode} mono buffer" @@ -476,7 +473,6 @@ def extract_waveforms_to_unique_buffer( processor = ChunkRecordingExecutor(recording, func, init_func, init_args, job_name=job_name, **job_kwargs) processor.run() - if mode == "memmap": return all_waveforms elif mode == "shared_memory": @@ -490,12 +486,9 @@ def extract_waveforms_to_unique_buffer( return all_waveforms, wf_array_info - - def _init_worker_ditribute_one_buffer( recording, unit_ids, spikes, wf_array_info, nbefore, nafter, return_scaled, mode, sparsity_mask ): - worker_ctx = {} worker_ctx["recording"] = recording worker_ctx["wf_array_info"] = wf_array_info @@ -513,6 +506,7 @@ def _init_worker_ditribute_one_buffer( worker_ctx["all_waveforms"] = all_waveforms elif mode == "shared_memory": from multiprocessing.shared_memory import SharedMemory + shm, shm_name, dtype, shape = wf_array_info shm = SharedMemory(shm_name) all_waveforms = np.ndarray(shape=shape, dtype=dtype, buffer=shm.buf) @@ -546,7 +540,7 @@ def _worker_ditribute_one_buffer(segment_index, start_frame, end_frame, worker_c seg_size = recording.get_num_samples(segment_index=segment_index) s0, s1 = segment_slices[segment_index] - in_seg_spikes = spikes[s0: s1] + in_seg_spikes = spikes[s0:s1] # take only spikes in range [start_frame, end_frame] # this is a slice so no copy!! @@ -584,7 +578,7 @@ def _worker_ditribute_one_buffer(segment_index, start_frame, end_frame, worker_c else: mask = sparsity_mask[unit_index, :] wf = wf[:, mask] - all_waveforms[spike_ind, :, :wf.shape[1]] = wf + all_waveforms[spike_ind, :, : wf.shape[1]] = wf def split_waveforms_by_units(unit_ids, spikes, all_waveforms, sparsity_mask=None): From 45036bf72c84376edeebb531f4feebeb8f0255e2 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 10 Jul 2023 14:30:34 +0200 Subject: [PATCH 08/73] wip waveforms tools single buffer --- .../core/tests/test_waveform_tools.py | 4 +- src/spikeinterface/core/waveform_tools.py | 173 +++++++++++++----- 2 files changed, 131 insertions(+), 46 deletions(-) diff --git a/src/spikeinterface/core/tests/test_waveform_tools.py b/src/spikeinterface/core/tests/test_waveform_tools.py index 9a51a10ee2..fb65e87458 100644 --- a/src/spikeinterface/core/tests/test_waveform_tools.py +++ b/src/spikeinterface/core/tests/test_waveform_tools.py @@ -8,7 +8,7 @@ from spikeinterface.core import generate_recording, generate_sorting from spikeinterface.core.waveform_tools import ( extract_waveforms_to_buffers, - extract_waveforms_to_unique_buffer, + extract_waveforms_to_single_buffer, split_waveforms_by_units, ) @@ -113,7 +113,7 @@ def test_waveform_tools(): else: list_wfs_sparse.append(wfs_arrays) - all_waveforms = extract_waveforms_to_unique_buffer( + all_waveforms = extract_waveforms_to_single_buffer( recording, spikes, unit_ids, diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index e6f7e944cc..252ea68738 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -36,7 +36,7 @@ def extract_waveforms_to_buffers( Same as calling allocate_waveforms_buffers() and then distribute_waveforms_to_buffers(). 
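For context on the docstring touched in this hunk: extract_waveforms_to_buffers() wraps the two-step path of allocating per-unit buffers and then filling them chunk by chunk. A sketch of that two-step equivalent, with argument lists following the calls that appear elsewhere in this series (toy data and values are illustrative only):

    from pathlib import Path
    from spikeinterface.core import generate_recording, generate_sorting
    from spikeinterface.core.waveform_tools import (
        allocate_waveforms_buffers,
        distribute_waveforms_to_buffers,
    )

    recording = generate_recording(num_channels=2, durations=[10.0], sampling_frequency=30000.0)
    sorting = generate_sorting(num_units=3, durations=[10.0], sampling_frequency=30000.0)
    spikes = sorting.to_spike_vector()
    unit_ids = sorting.unit_ids
    nbefore, nafter = 20, 40

    wf_folder = Path("waveforms_tmp")
    wf_folder.mkdir(exist_ok=True)

    # step 1: pre-allocate one (num_spikes_in_unit, num_samples, num_channels) buffer per unit
    waveforms_by_units, arrays_info = allocate_waveforms_buffers(
        recording, spikes, unit_ids, nbefore, nafter,
        mode="memmap", folder=wf_folder, dtype="float32", sparsity_mask=None,
    )
    # step 2: fill those buffers chunk by chunk with parallel workers
    distribute_waveforms_to_buffers(
        recording, spikes, unit_ids, arrays_info, nbefore, nafter,
        return_scaled=False, mode="memmap", sparsity_mask=None,
        n_jobs=1, chunk_size=3000,
    )
    # waveforms_by_units[unit_id] now holds the extracted snippets for that unit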
- Important note: for the "shared_memory" mode wfs_arrays_info contains reference to + Important note: for the "shared_memory" mode arrays_info contains reference to the shared memmory buffer, this variable must be reference as long as arrays as used. And this variable is also returned. To avoid this a copy to non shared memmory can be perform at the end. @@ -66,17 +66,17 @@ def extract_waveforms_to_buffers( If not None shape must be must be (len(unit_ids), len(channel_ids)) copy: bool If True (default), the output shared memory object is copied to a numpy standard array. - If copy=False then wfs_arrays_info is also return. Please keep in mind that wfs_arrays_info - need to be referenced as long as wfs_arrays will be used otherwise it will be very hard to debug. + If copy=False then arrays_info is also return. Please keep in mind that arrays_info + need to be referenced as long as waveforms_by_units will be used otherwise it will be very hard to debug. Also when copy=False the SharedMemory will need to be unlink manually {} Returns ------- - wfs_arrays: dict of arrays + waveforms_by_units: dict of arrays Arrays for all units (memmap or shared_memmep) - wfs_arrays_info: dict of info + arrays_info: dict of info Optionally return in case of shared_memory if copy=False. Dictionary to "construct" array in workers process (memmap file or sharemem info) """ @@ -89,7 +89,7 @@ def extract_waveforms_to_buffers( dtype = "float32" dtype = np.dtype(dtype) - wfs_arrays, wfs_arrays_info = allocate_waveforms_buffers( + waveforms_by_units, arrays_info = allocate_waveforms_buffers( recording, spikes, unit_ids, nbefore, nafter, mode=mode, folder=folder, dtype=dtype, sparsity_mask=sparsity_mask ) @@ -97,7 +97,7 @@ def extract_waveforms_to_buffers( recording, spikes, unit_ids, - wfs_arrays_info, + arrays_info, nbefore, nafter, return_scaled, @@ -107,19 +107,19 @@ def extract_waveforms_to_buffers( ) if mode == "memmap": - return wfs_arrays + return waveforms_by_units elif mode == "shared_memory": if copy: - wfs_arrays = {unit_id: arr.copy() for unit_id, arr in wfs_arrays.items()} + waveforms_by_units = {unit_id: arr.copy() for unit_id, arr in waveforms_by_units.items()} # release all sharedmem buffer for unit_id in unit_ids: - shm = wfs_arrays_info[unit_id][0] + shm = arrays_info[unit_id][0] if shm is not None: # empty array have None shm.unlink() - return wfs_arrays + return waveforms_by_units else: - return wfs_arrays, wfs_arrays_info + return waveforms_by_units, arrays_info extract_waveforms_to_buffers.__doc__ = extract_waveforms_to_buffers.__doc__.format(_shared_job_kwargs_doc) @@ -131,7 +131,7 @@ def allocate_waveforms_buffers( """ Allocate memmap or shared memory buffers before snippet extraction. - Important note: for the shared memory mode wfs_arrays_info contains reference to + Important note: for the shared memory mode arrays_info contains reference to the shared memmory buffer, this variable must be reference as long as arrays as used. 
Parameters @@ -158,9 +158,9 @@ def allocate_waveforms_buffers( Returns ------- - wfs_arrays: dict of arrays + waveforms_by_units: dict of arrays Arrays for all units (memmap or shared_memmep - wfs_arrays_info: dict of info + arrays_info: dict of info Dictionary to "construct" array in workers process (memmap file or sharemem) """ @@ -173,8 +173,8 @@ def allocate_waveforms_buffers( folder = Path(folder) # prepare buffers - wfs_arrays = {} - wfs_arrays_info = {} + waveforms_by_units = {} + arrays_info = {} for unit_ind, unit_id in enumerate(unit_ids): n_spikes = np.sum(spikes["unit_index"] == unit_ind) if sparsity_mask is None: @@ -186,8 +186,8 @@ def allocate_waveforms_buffers( if mode == "memmap": filename = str(folder / f"waveforms_{unit_id}.npy") arr = np.lib.format.open_memmap(filename, mode="w+", dtype=dtype, shape=shape) - wfs_arrays[unit_id] = arr - wfs_arrays_info[unit_id] = filename + waveforms_by_units[unit_id] = arr + arrays_info[unit_id] = filename elif mode == "shared_memory": if n_spikes == 0 or num_chans == 0: arr = np.zeros(shape, dtype=dtype) @@ -196,19 +196,19 @@ def allocate_waveforms_buffers( else: arr, shm = make_shared_array(shape, dtype) shm_name = shm.name - wfs_arrays[unit_id] = arr - wfs_arrays_info[unit_id] = (shm, shm_name, dtype.str, shape) + waveforms_by_units[unit_id] = arr + arrays_info[unit_id] = (shm, shm_name, dtype.str, shape) else: raise ValueError("allocate_waveforms_buffers bad mode") - return wfs_arrays, wfs_arrays_info + return waveforms_by_units, arrays_info def distribute_waveforms_to_buffers( recording, spikes, unit_ids, - wfs_arrays_info, + arrays_info, nbefore, nafter, return_scaled, @@ -222,7 +222,7 @@ def distribute_waveforms_to_buffers( Buffers must be pre-allocated with the `allocate_waveforms_buffers()` function. - Important note, for "shared_memory" mode wfs_arrays_info contain reference to + Important note, for "shared_memory" mode arrays_info contain reference to the shared memmory buffer, this variable must be reference as long as arrays as used. 
Parameters @@ -234,7 +234,7 @@ def distribute_waveforms_to_buffers( This vector can be spikes = Sorting.to_spike_vector() unit_ids: list ot numpy List of unit_ids - wfs_arrays_info: dict + arrays_info: dict Dictionary to "construct" array in workers process (memmap file or sharemem) nbefore: int N samples before spike @@ -265,7 +265,7 @@ def distribute_waveforms_to_buffers( recording, unit_ids, spikes, - wfs_arrays_info, + arrays_info, nbefore, nafter, return_scaled, @@ -284,7 +284,7 @@ def distribute_waveforms_to_buffers( # used by ChunkRecordingExecutor def _init_worker_ditribute_buffers( - recording, unit_ids, spikes, wfs_arrays_info, nbefore, nafter, return_scaled, inds_by_unit, mode, sparsity_mask + recording, unit_ids, spikes, arrays_info, nbefore, nafter, return_scaled, inds_by_unit, mode, sparsity_mask ): # create a local dict per worker worker_ctx = {} @@ -297,23 +297,23 @@ def _init_worker_ditribute_buffers( if mode == "memmap": # in memmap mode we have the "too many open file" problem with linux # memmap file will be open on demand and not globally per worker - worker_ctx["wfs_arrays_info"] = wfs_arrays_info + worker_ctx["arrays_info"] = arrays_info elif mode == "shared_memory": from multiprocessing.shared_memory import SharedMemory - wfs_arrays = {} + waveforms_by_units = {} shms = {} - for unit_id, (shm, shm_name, dtype, shape) in wfs_arrays_info.items(): + for unit_id, (shm, shm_name, dtype, shape) in arrays_info.items(): if shm_name is None: arr = np.zeros(shape=shape, dtype=dtype) else: shm = SharedMemory(shm_name) arr = np.ndarray(shape=shape, dtype=dtype, buffer=shm.buf) - wfs_arrays[unit_id] = arr + waveforms_by_units[unit_id] = arr # we need a reference to all sham otherwise we get segment fault!!! shms[unit_id] = shm worker_ctx["shms"] = shms - worker_ctx["wfs_arrays"] = wfs_arrays + worker_ctx["waveforms_by_units"] = waveforms_by_units worker_ctx["unit_ids"] = unit_ids worker_ctx["spikes"] = spikes @@ -383,10 +383,10 @@ def _worker_ditribute_buffers(segment_index, start_frame, end_frame, worker_ctx) if worker_ctx["mode"] == "memmap": # open file in demand (and also autoclose it after) - filename = worker_ctx["wfs_arrays_info"][unit_id] + filename = worker_ctx["arrays_info"][unit_id] wfs = np.load(str(filename), mmap_mode="r+") elif worker_ctx["mode"] == "shared_memory": - wfs = worker_ctx["wfs_arrays"][unit_id] + wfs = worker_ctx["waveforms_by_units"][unit_id] for pos in in_chunk_pos: sample_index = spikes[inds[pos]]["sample_index"] @@ -398,7 +398,7 @@ def _worker_ditribute_buffers(segment_index, start_frame, end_frame, worker_ctx) wfs[pos, :, :] = wf[:, sparsity_mask[unit_ind]] -def extract_waveforms_to_unique_buffer( +def extract_waveforms_to_single_buffer( recording, spikes, unit_ids, @@ -413,6 +413,57 @@ def extract_waveforms_to_unique_buffer( job_name=None, **job_kwargs, ): + """ + Allocate a single buffer (memmap or or shared memory) and then distribute every waveform into it. + + Contrary to extract_waveforms_to_buffers() all waveforms are extracted in the same buffer, so the spike vector is + needed to recover waveforms unit by unit. Importantly in case of sparsity, the channel are not aligned across + units. + + Important note: for the "shared_memory" mode wf_array_info contains reference to + the shared memmory buffer, this variable must be reference as long as arrays as used. + And this variable is also returned. + To avoid this a copy to non shared memmory can be perform at the end. 
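The note above describes the copy=False contract for "shared_memory" mode: the returned wf_array_info (the (shm, shm_name, dtype_str, shape) tuple built in this function) has to stay referenced while the array is in use, and the caller releases the SharedMemory afterwards. A small sketch of that usage, with toy data and an illustrative computation (not taken from the patch):

    from spikeinterface.core import generate_recording, generate_sorting
    from spikeinterface.core.waveform_tools import extract_waveforms_to_single_buffer

    recording = generate_recording(num_channels=2, durations=[10.0], sampling_frequency=30000.0)
    sorting = generate_sorting(num_units=3, durations=[10.0], sampling_frequency=30000.0)
    spikes = sorting.to_spike_vector()

    all_waveforms, wf_array_info = extract_waveforms_to_single_buffer(
        recording, spikes, sorting.unit_ids, 20, 40,
        mode="shared_memory", return_scaled=False, dtype="float32",
        sparsity_mask=None, copy=False, n_jobs=1, chunk_size=3000,
    )
    # use the buffer while wf_array_info is still referenced
    peak_to_peak = all_waveforms.max(axis=1) - all_waveforms.min(axis=1)

    shm = wf_array_info[0]
    if shm is not None:  # an empty spike vector leaves shm as None
        shm.unlink()     # the caller releases the shared-memory block when done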
+ + Parameters + ---------- + recording: recording + The recording object + spikes: 1d numpy array with several fields + Spikes handled as a unique vector. + This vector can be obtained with: `spikes = Sorting.to_spike_vector()` + unit_ids: list ot numpy + List of unit_ids + nbefore: int + N samples before spike + nafter: int + N samples after spike + mode: str + Mode to use ('memmap' | 'shared_memory') + return_scaled: bool + Scale traces before exporting to buffer or not. + folder: str or path + In case of memmap mode, folder to save npy files + dtype: numpy.dtype + dtype for waveforms buffer + sparsity_mask: None or array of bool + If not None shape must be must be (len(unit_ids), len(channel_ids)) + copy: bool + If True (default), the output shared memory object is copied to a numpy standard array. + If copy=False then arrays_info is also return. Please keep in mind that arrays_info + need to be referenced as long as waveforms_by_units will be used otherwise it will be very hard to debug. + Also when copy=False the SharedMemory will need to be unlink manually + {} + + Returns + ------- + all_waveforms: numpy array + Single array with shape (nump_spikes, num_samples, num_channels) + + wf_array_info: dict of info + Optionally return in case of shared_memory if copy=False. + Dictionary to "construct" array in workers process (memmap file or sharemem info) + """ nsamples = nbefore + nafter dtype = np.dtype(dtype) @@ -453,8 +504,8 @@ def extract_waveforms_to_unique_buffer( if num_spikes > 0: # and run - func = _worker_ditribute_one_buffer - init_func = _init_worker_ditribute_one_buffer + func = _worker_ditribute_single_buffer + init_func = _init_worker_ditribute_single_buffer init_args = ( recording, @@ -486,7 +537,7 @@ def extract_waveforms_to_unique_buffer( return all_waveforms, wf_array_info -def _init_worker_ditribute_one_buffer( +def _init_worker_ditribute_single_buffer( recording, unit_ids, spikes, wf_array_info, nbefore, nafter, return_scaled, mode, sparsity_mask ): worker_ctx = {} @@ -525,7 +576,7 @@ def _init_worker_ditribute_one_buffer( # used by ChunkRecordingExecutor -def _worker_ditribute_one_buffer(segment_index, start_frame, end_frame, worker_ctx): +def _worker_ditribute_single_buffer(segment_index, start_frame, end_frame, worker_ctx): # recover variables of the worker recording = worker_ctx["recording"] unit_ids = worker_ctx["unit_ids"] @@ -580,19 +631,53 @@ def _worker_ditribute_one_buffer(segment_index, start_frame, end_frame, worker_c wf = wf[:, mask] all_waveforms[spike_ind, :, : wf.shape[1]] = wf + if worker_ctx["mode"] == "memmap": + all_waveforms.flush() + + +def split_waveforms_by_units(unit_ids, spikes, all_waveforms, sparsity_mask=None, folder=None): + """ + Split a single buffer waveforms into waveforms by units (multi buffers or multi files). + + Parameters + ---------- + unit_ids: list or numpy array + List of unit ids + spikes: numpy array + The spike vector + all_waveforms : numpy array + Single buffer containing all waveforms + sparsity_mask : None or numpy array + Optionally the boolean sparsity mask + folder : None or str or Path + If a folde ri sgiven all -def split_waveforms_by_units(unit_ids, spikes, all_waveforms, sparsity_mask=None): - waveform_by_units = {} + Returns + ------- + waveforms_by_units: dict of array + A dict of arrays. + In case of folder not None, this contain the memmap of the files. 
+ """ + if folder is not None: + folder = Path(folder) + waveforms_by_units = {} for unit_index, unit_id in enumerate(unit_ids): mask = spikes["unit_index"] == unit_index if sparsity_mask is not None: chan_mask = sparsity_mask[unit_index, :] num_chans = np.sum(chan_mask) - waveform_by_units[unit_id] = all_waveforms[mask, :, :][:, :, :num_chans] + wfs = all_waveforms[mask, :, :][:, :, :num_chans] + else: + wfs = all_waveforms[mask, :, :] + + if folder is None: + waveforms_by_units[unit_id] = wfs else: - waveform_by_units[unit_id] = all_waveforms[mask, :, :] + np.save(folder / f"waveforms_{unit_id}.npy", wfs) + # this avoid keeping in memory all waveforms + waveforms_by_units[unit_id] = np.load(f"waveforms_{unit_id}.npy", mmap_mode="r") - return waveform_by_units + return waveforms_by_units def has_exceeding_spikes(recording, sorting): From 570a3a5fa8dc1b340da4ee0fc6baa16013213bdb Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 26 Jul 2023 17:57:40 +0200 Subject: [PATCH 09/73] fix local tests --- src/spikeinterface/core/tests/test_waveform_tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/tests/test_waveform_tools.py b/src/spikeinterface/core/tests/test_waveform_tools.py index fb65e87458..52d7472c92 100644 --- a/src/spikeinterface/core/tests/test_waveform_tools.py +++ b/src/spikeinterface/core/tests/test_waveform_tools.py @@ -86,7 +86,7 @@ def test_waveform_tools(): wf_folder = cache_folder / f"test_waveform_tools_{j}_{k}_{l}" if wf_folder.is_dir(): shutil.rmtree(wf_folder) - wf_folder.mkdir() + wf_folder.mkdir(parents=True) mode_kwargs_ = dict(**mode_kwargs, folder=wf_folder) else: mode_kwargs_ = mode_kwargs From f737f5a4b4b998321ff2999ec57915f1c1b6dc82 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 24 Aug 2023 15:14:55 +0200 Subject: [PATCH 10/73] wip collisions --- .../postprocessing/amplitude_scalings.py | 384 +++++++++++++++++- 1 file changed, 370 insertions(+), 14 deletions(-) diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 3ebeafcfec..7539e4d0b7 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -22,8 +22,25 @@ def __init__(self, waveform_extractor): extremum_channel_inds=extremum_channel_inds, use_cache=False ) - def _set_params(self, sparsity, max_dense_channels, ms_before, ms_after): - params = dict(sparsity=sparsity, max_dense_channels=max_dense_channels, ms_before=ms_before, ms_after=ms_after) + def _set_params( + self, + sparsity, + max_dense_channels, + ms_before, + ms_after, + handle_collisions, + max_consecutive_collisions, + delta_collision_ms, + ): + params = dict( + sparsity=sparsity, + max_dense_channels=max_dense_channels, + ms_before=ms_before, + ms_after=ms_after, + handle_collisions=handle_collisions, + max_consecutive_collisions=max_consecutive_collisions, + delta_collision_ms=delta_collision_ms, + ) return params def _select_extension_data(self, unit_ids): @@ -43,6 +60,12 @@ def _run(self, **job_kwargs): ms_before = self._params["ms_before"] ms_after = self._params["ms_after"] + # collisions + handle_collisions = self._params["handle_collisions"] + max_consecutive_collisions = self._params["max_consecutive_collisions"] + delta_collision_ms = self._params["delta_collision_ms"] + delta_collision_samples = int(delta_collision_ms / 1000 * we.sampling_frequency) + return_scaled = we._params["return_scaled"] unit_ids = we.unit_ids @@ -67,6 
+90,8 @@ def _run(self, **job_kwargs): assert recording.get_num_channels() <= self._params["max_dense_channels"], "" sparsity = ChannelSparsity.create_dense(we) sparsity_inds = sparsity.unit_id_to_channel_indices + + # easier to use in chunk function as spikes use unit_index instead o id unit_inds_to_channel_indices = {unit_ind: sparsity_inds[unit_id] for unit_ind, unit_id in enumerate(unit_ids)} all_templates = we.get_all_templates() @@ -93,6 +118,9 @@ def _run(self, **job_kwargs): cut_out_before, cut_out_after, return_scaled, + handle_collisions, + max_consecutive_collisions, + delta_collision_samples, ) processor = ChunkRecordingExecutor( recording, @@ -154,6 +182,9 @@ def compute_amplitude_scalings( max_dense_channels=16, ms_before=None, ms_after=None, + handle_collisions=False, + max_consecutive_collisions=3, + delta_collision_ms=2, load_if_exists=False, outputs="concatenated", **job_kwargs, @@ -165,22 +196,29 @@ def compute_amplitude_scalings( ---------- waveform_extractor: WaveformExtractor The waveform extractor object - sparsity: ChannelSparsity + sparsity: ChannelSparsity, default: None If waveforms are not sparse, sparsity is required if the number of channels is greater than `max_dense_channels`. If the waveform extractor is sparse, its sparsity is automatically used. - By default None max_dense_channels: int, default: 16 Maximum number of channels to allow running without sparsity. To compute amplitude scaling using dense waveforms, set this to None, sparsity to None, and pass dense waveforms as input. - ms_before : float, optional + ms_before : float, default: None The cut out to apply before the spike peak to extract local waveforms. - If None, the WaveformExtractor ms_before is used, by default None - ms_after : float, optional + If None, the WaveformExtractor ms_before is used. + ms_after : float, default: None The cut out to apply after the spike peak to extract local waveforms. - If None, the WaveformExtractor ms_after is used, by default None + If None, the WaveformExtractor ms_after is used. + handle_collisions: bool, default: False + Whether to handle collisions between spikes. If True, the amplitude scaling of colliding spikes + (defined as spikes within `delta_collision_ms` ms and with overlapping sparsity) is computed by fitting a + multi-linear regression model (with `sklearn.LinearRegression`). If False, each spike is fitted independently. + max_consecutive_collisions: int, default: 3 + The maximum number of consecutive collisions to handle on each side of a spike. + delta_collision_ms: float, default: 2 + The maximum time difference in ms between two spikes to be considered as colliding. load_if_exists : bool, default: False Whether to load precomputed spike amplitudes, if they already exist. 
- outputs: str + outputs: str, default: 'concatenated' How the output should be returned: - 'concatenated' - 'by_unit' @@ -197,7 +235,15 @@ def compute_amplitude_scalings( sac = waveform_extractor.load_extension(AmplitudeScalingsCalculator.extension_name) else: sac = AmplitudeScalingsCalculator(waveform_extractor) - sac.set_params(sparsity=sparsity, max_dense_channels=max_dense_channels, ms_before=ms_before, ms_after=ms_after) + sac.set_params( + sparsity=sparsity, + max_dense_channels=max_dense_channels, + ms_before=ms_before, + ms_after=ms_after, + handle_collisions=handle_collisions, + max_consecutive_collisions=max_consecutive_collisions, + delta_collision_ms=delta_collision_ms, + ) sac.run(**job_kwargs) amps = sac.get_data(outputs=outputs) @@ -218,6 +264,9 @@ def _init_worker_amplitude_scalings( cut_out_before, cut_out_after, return_scaled, + handle_collisions, + max_consecutive_collisions, + delta_collision_samples, ): # create a local dict per worker worker_ctx = {} @@ -229,9 +278,18 @@ def _init_worker_amplitude_scalings( worker_ctx["nafter"] = nafter worker_ctx["cut_out_before"] = cut_out_before worker_ctx["cut_out_after"] = cut_out_after - worker_ctx["margin"] = max(nbefore, nafter) worker_ctx["return_scaled"] = return_scaled worker_ctx["unit_inds_to_channel_indices"] = unit_inds_to_channel_indices + worker_ctx["handle_collisions"] = handle_collisions + worker_ctx["max_consecutive_collisions"] = max_consecutive_collisions + worker_ctx["delta_collision_samples"] = delta_collision_samples + + if not handle_collisions: + worker_ctx["margin"] = max(nbefore, nafter) + else: + margin_waveforms = max(nbefore, nafter) + max_margin_collisions = int(max_consecutive_collisions * delta_collision_samples) + worker_ctx["margin"] = max(margin_waveforms, max_margin_collisions) return worker_ctx @@ -250,6 +308,9 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) cut_out_after = worker_ctx["cut_out_after"] margin = worker_ctx["margin"] return_scaled = worker_ctx["return_scaled"] + handle_collisions = worker_ctx["handle_collisions"] + max_consecutive_collisions = worker_ctx["max_consecutive_collisions"] + delta_collision_samples = worker_ctx["delta_collision_samples"] spikes_in_segment = spikes[segment_slices[segment_index]] @@ -272,8 +333,24 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) offsets = recording.get_property("offset_to_uV") traces_with_margin = traces_with_margin.astype("float32") * gains + offsets + # set colliding spikes apart (if needed) + if handle_collisions: + overlapping_mask = _find_overlapping_mask( + local_spikes, max_consecutive_collisions, delta_collision_samples, unit_inds_to_channel_indices + ) + overlapping_spike_indices = overlapping_mask[:, max_consecutive_collisions] + print( + f"Found {len(overlapping_spike_indices)} overlapping spikes in segment {segment_index}! 
- chunk {start_frame} - {end_frame}" + ) + else: + overlapping_spike_indices = np.array([], dtype=int) + # get all waveforms - for spike in local_spikes: + scalings = np.zeros(len(local_spikes), dtype=float) + for spike_index, spike in enumerate(local_spikes): + if spike_index in overlapping_spike_indices: + # we deal with overlapping spikes later + continue unit_index = spike["unit_index"] sample_index = spike["sample_index"] sparse_indices = unit_inds_to_channel_indices[unit_index] @@ -294,7 +371,286 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) local_waveforms.append(local_waveform) templates.append(template) linregress_res = linregress(template.flatten(), local_waveform.flatten()) - scalings.append(linregress_res[0]) - scalings = np.array(scalings) + scalings[spike_index] = linregress_res[0] + + # deal with collisions + if len(overlapping_spike_indices) > 0: + for overlapping in overlapping_mask: + spike_index = overlapping[max_consecutive_collisions] + overlapping_spikes = local_spikes[overlapping[overlapping >= 0]] + scaled_amps = _fit_collision( + overlapping_spikes, + traces_with_margin, + start_frame, + end_frame, + left, + right, + nbefore, + all_templates, + unit_inds_to_channel_indices, + cut_out_before, + cut_out_after, + ) + # get the right amplitude scaling + scalings[spike_index] = scaled_amps[np.where(overlapping >= 0)[0] == max_consecutive_collisions] return (scalings,) + + +### Collision handling ### +def _are_unit_indices_overlapping(unit_inds_to_channel_indices, i, j): + """ + Returns True if the unit indices i and j are overlapping, False otherwise + + Parameters + ---------- + unit_inds_to_channel_indices: dict + A dictionary mapping unit indices to channel indices + i: int + The first unit index + j: int + The second unit index + + Returns + ------- + bool + True if the unit indices i and j are overlapping, False otherwise + """ + if len(np.intersect1d(unit_inds_to_channel_indices[i], unit_inds_to_channel_indices[j])) > 0: + return True + else: + return False + + +def _find_overlapping_mask(spikes, max_consecutive_spikes, delta_overlap_samples, unit_inds_to_channel_indices): + """ + Finds the overlapping spikes for each spike in spikes and returns a boolean mask of shape + (n_spikes, 2 * max_consecutive_spikes + 1). + + Parameters + ---------- + spikes: np.array + An array of spikes + max_consecutive_spikes: int + The maximum number of consecutive spikes to consider + delta_overlap_samples: int + The maximum number of samples between two spikes to consider them as overlapping + unit_inds_to_channel_indices: dict + A dictionary mapping unit indices to channel indices + + Returns + ------- + overlapping_mask: np.array + A boolean mask of shape (n_spikes, 2 * max_consecutive_spikes + 1) where the central column (max_consecutive_spikes) + is the current spike index, while the other columns are the indices of the overlapping spikes. The first + max_consecutive_spikes columns are the pre-overlapping spikes, while the last max_consecutive_spikes columns are + the post-overlapping spikes. + """ + + # overlapping_mask is a matrix of shape (n_spikes, 2 * max_consecutive_spikes + 1) + # the central column (max_consecutive_spikes) is the current spike index, while the other columns are the + # indices of the overlapping spikes. 
The first max_consecutive_spikes columns are the pre-overlapping spikes, + # while the last max_consecutive_spikes columns are the post-overlapping spikes + # Rows with all -1 are non-colliding spikes and are removed later + overlapping_mask_full = -1 * np.ones((len(spikes), 2 * max_consecutive_spikes + 1), dtype=int) + overlapping_mask_full[:, max_consecutive_spikes] = np.arange(len(spikes)) + + for i, spike in enumerate(spikes): + # find the possible spikes per and post within max_consecutive_spikes * delta_overlap_samples + consecutive_window_pre = np.searchsorted( + spikes["sample_index"], + spike["sample_index"] - max_consecutive_spikes * delta_overlap_samples, + ) + consecutive_window_post = np.searchsorted( + spikes["sample_index"], + spike["sample_index"] + max_consecutive_spikes * delta_overlap_samples, + ) + pre_possible_consecutive_spikes = spikes[consecutive_window_pre:i][::-1] + post_possible_consecutive_spikes = spikes[i + 1 : consecutive_window_post] + + # here we fill in the overlapping information by consecutively looping through the possible consecutive spikes + # and checking the spatial overlap and the delay with the previous overlapping spike + # pre and post are hanlded separately. Note that the pre-spikes are already sorted backwards + + # overlap_rank keeps track of the rank of consecutive collisions (i.e., rank 0 is the first, rank 1 is the second, etc.) + # this is needed because we are just considering spikes with spatial overlap, while the possible consecutive spikes + # only looked at the temporal overlap + overlap_rank = 0 + if len(pre_possible_consecutive_spikes) > 0: + for c_pre, spike_consecutive_pre in enumerate(pre_possible_consecutive_spikes[::-1]): + if _are_unit_indices_overlapping( + unit_inds_to_channel_indices, spike["unit_index"], spike_consecutive_pre["unit_index"] + ): + if ( + spikes[overlapping_mask_full[i, max_consecutive_spikes - overlap_rank]]["sample_index"] + - spike_consecutive_pre["sample_index"] + < delta_overlap_samples + ): + overlapping_mask_full[i, max_consecutive_spikes - overlap_rank - 1] = i - 1 - c_pre + overlap_rank += 1 + else: + break + # if overlap_rank > 1: + # print(f"\tHigher order pre-overlap for spike {i}!") + + overlap_rank = 0 + if len(post_possible_consecutive_spikes) > 0: + for c_post, spike_consecutive_post in enumerate(post_possible_consecutive_spikes): + if _are_unit_indices_overlapping( + unit_inds_to_channel_indices, spike["unit_index"], spike_consecutive_post["unit_index"] + ): + if ( + spike_consecutive_post["sample_index"] + - spikes[overlapping_mask_full[i, max_consecutive_spikes + overlap_rank]]["sample_index"] + < delta_overlap_samples + ): + overlapping_mask_full[i, max_consecutive_spikes + overlap_rank + 1] = i + 1 + c_post + overlap_rank += 1 + else: + break + # if overlap_rank > 1: + # print(f"\tHigher order post-overlap for spike {i}!") + + # in case no collisions were found, we set the central column to -1 so that we can easily identify the non-colliding spikes + if np.sum(overlapping_mask_full[i] != -1) == 1: + overlapping_mask_full[i, max_consecutive_spikes] = -1 + + # only return rows with collisions + overlapping_inds = [] + for i, overlapping in enumerate(overlapping_mask_full): + if np.any(overlapping >= 0): + overlapping_inds.append(i) + overlapping_mask = overlapping_mask_full[overlapping_inds] + + return overlapping_mask + + +def _fit_collision( + overlapping_spikes, + traces_with_margin, + start_frame, + end_frame, + left, + right, + nbefore, + all_templates, + 
unit_inds_to_channel_indices, + cut_out_before, + cut_out_after, +): + """ """ + from sklearn.linear_model import LinearRegression + + sample_first_centered = overlapping_spikes[0]["sample_index"] - start_frame - left + sample_last_centered = overlapping_spikes[-1]["sample_index"] - start_frame - left + + # construct sparsity as union between units' sparsity + sparse_indices = np.array([], dtype="int") + for spike in overlapping_spikes: + sparse_indices_i = unit_inds_to_channel_indices[spike["unit_index"]] + sparse_indices = np.union1d(sparse_indices, sparse_indices_i) + + local_waveform_start = max(0, sample_first_centered - cut_out_before) + local_waveform_end = min(traces_with_margin.shape[0], sample_last_centered + cut_out_after) + local_waveform = traces_with_margin[local_waveform_start:local_waveform_end, sparse_indices] + + y = local_waveform.T.flatten() + X = np.zeros((len(y), len(overlapping_spikes))) + for i, spike in enumerate(overlapping_spikes): + full_template = np.zeros_like(local_waveform) + # center wrt cutout traces + sample_centered = spike["sample_index"] - local_waveform_start + template = all_templates[spike["unit_index"]][:, sparse_indices] + template_cut = template[nbefore - cut_out_before : nbefore + cut_out_after] + # deal with borders + if sample_centered - cut_out_before < 0: + full_template[: sample_centered + cut_out_after] = template_cut[cut_out_before - sample_centered :] + elif sample_centered + cut_out_after > end_frame + right: + full_template[sample_centered - cut_out_before :] = template_cut[: -cut_out_after - (end_frame + right)] + else: + full_template[sample_centered - cut_out_before : sample_centered + cut_out_after] = template_cut + X[:, i] = full_template.T.flatten() + + reg = LinearRegression().fit(X, y) + amps = reg.coef_ + return amps + + +# TODO: fix this! 
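In essence, the fit above poses the collision window as a linear superposition of the colliding units' templates and recovers one scaling factor per spike with an ordinary least-squares fit. A self-contained toy sketch of that idea (synthetic templates and amplitudes, purely illustrative and not taken from this patch):

import numpy as np
from sklearn.linear_model import LinearRegression

rng = np.random.default_rng(0)
n_samples = 60
t = np.arange(n_samples)

# two toy templates, shape (num_samples, num_channels), peaking at different times/channels
template_a = np.outer(np.exp(-((t - 20) ** 2) / 30.0), [1.0, 0.6, 0.2, 0.0])
template_b = np.outer(np.exp(-((t - 35) ** 2) / 30.0), [0.0, 0.3, 0.8, 1.0])

# a synthetic "collision": both templates present, each with its own scaling, plus noise
true_scalings = np.array([1.3, 0.7])
traces = true_scalings[0] * template_a + true_scalings[1] * template_b
traces = traces + 0.01 * rng.standard_normal(traces.shape)

# one design-matrix column per colliding spike: its time-aligned template, flattened channel-wise
X = np.column_stack([template_a.T.flatten(), template_b.T.flatten()])
y = traces.T.flatten()

scalings = LinearRegression().fit(X, y).coef_
print(scalings)  # close to [1.3, 0.7]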
+# def plot_overlapping_spikes(we, overlap, +# spikes, cut_out_samples=100, +# max_consecutive_spikes=3, +# sparsity=None, +# fitted_amps=None): +# recording = we.recording +# nbefore_nafter_max = max(we.nafter, we.nbefore) +# cut_out_samples = max(cut_out_samples, nbefore_nafter_max) +# spike_index = overlap[max_consecutive_spikes] +# overlap_indices = overlap[overlap != -1] +# overlapping_spikes = spikes[overlap_indices] + +# if sparsity is not None: +# unit_inds_to_channel_indices = sparsity.unit_id_to_channel_indices +# sparse_indices = np.array([], dtype="int") +# for spike in overlapping_spikes: +# sparse_indices_i = unit_inds_to_channel_indices[we.unit_ids[spike["unit_index"]]] +# sparse_indices = np.union1d(sparse_indices, sparse_indices_i) +# else: +# sparse_indices = np.unique(overlapping_spikes["channel_index"]) + +# channel_ids = recording.channel_ids[sparse_indices] + +# center_spike = spikes[spike_index]["sample_index"] +# max_delta = np.max([np.abs(center_spike - overlapping_spikes[0]["sample_index"]), +# np.abs(center_spike - overlapping_spikes[-1]["sample_index"])]) +# sf = center_spike - max_delta - cut_out_samples +# ef = center_spike + max_delta + cut_out_samples +# tr_overlap = recording.get_traces(start_frame=sf, +# end_frame=ef, +# channel_ids=channel_ids, return_scaled=True) +# ts = np.arange(sf, ef) / recording.sampling_frequency * 1000 +# max_tr = np.max(np.abs(tr_overlap)) +# fig, ax = plt.subplots() +# for ch, tr in enumerate(tr_overlap.T): +# _ = ax.plot(ts, tr + 1.2 * ch * max_tr, color="k") +# ax.text(ts[0], 1.2 * ch * max_tr - 0.3 * max_tr, f"Ch:{channel_ids[ch]}") + +# used_labels = [] +# for spike in overlapping_spikes: +# label = f"U{spike['unit_index']}" +# if label in used_labels: +# label = None +# else: +# used_labels.append(label) +# ax.axvline(spike["sample_index"] / recording.sampling_frequency * 1000, +# color=f"C{spike['unit_index']}", label=label) + +# if fitted_amps is not None: +# fitted_traces = np.zeros_like(tr_overlap) + +# all_templates = we.get_all_templates() +# for i, spike in enumerate(overlapping_spikes): +# template = all_templates[spike["unit_index"]] +# template_scaled = fitted_amps[overlap_indices[i]] * template +# template_scaled_sparse = template_scaled[:, sparse_indices] +# sample_start = spike["sample_index"] - we.nbefore +# sample_end = sample_start + template_scaled_sparse.shape[0] + +# fitted_traces[sample_start - sf: sample_end - sf] += template_scaled_sparse + +# for ch, temp in enumerate(template_scaled_sparse.T): + +# ts_template = np.arange(sample_start, sample_end) / recording.sampling_frequency * 1000 +# _ = ax.plot(ts_template, temp + 1.2 * ch * max_tr, color=f"C{spike['unit_index']}", +# ls="--") + +# for ch, tr in enumerate(fitted_traces.T): +# _ = ax.plot(ts, tr + 1.2 * ch * max_tr, color="gray", alpha=0.7) + +# fitted_line = ax.get_lines()[-1] +# fitted_line.set_label("Fitted") + + +# ax.legend() +# ax.set_title(f"Spike {spike_index} - sample {center_spike}") +# return tr_overlap, ax From b9391a69c26f027e40da7cf0c3b7cffbf68b2d5e Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 28 Aug 2023 09:55:29 +0200 Subject: [PATCH 11/73] wip collisions --- .../postprocessing/amplitude_scalings.py | 301 ++++++++++++------ 1 file changed, 203 insertions(+), 98 deletions(-) diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 7539e4d0b7..d367ef4f22 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ 
b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -21,6 +21,7 @@ def __init__(self, waveform_extractor): self.spikes = self.waveform_extractor.sorting.to_spike_vector( extremum_channel_inds=extremum_channel_inds, use_cache=False ) + self.overlapping_mask = None def _set_params( self, @@ -132,8 +133,30 @@ def _run(self, **job_kwargs): **job_kwargs, ) out = processor.run() - (amp_scalings,) = zip(*out) + (amp_scalings, overlapping_mask) = zip(*out) amp_scalings = np.concatenate(amp_scalings) + if handle_collisions > 0: + from ..core.job_tools import divide_recording_into_chunks + + overlapping_mask_corrected = [] + all_chunks = divide_recording_into_chunks(processor.recording, processor.chunk_size) + num_spikes_so_far = 0 + for i, overlapping in enumerate(overlapping_mask): + if i == 0: + continue + segment_index = all_chunks[i - 1][0] + spikes_in_segment = self.spikes[segment_slices[segment_index]] + i0 = np.searchsorted(spikes_in_segment["sample_index"], all_chunks[i - 1][1]) + i1 = np.searchsorted(spikes_in_segment["sample_index"], all_chunks[i - 1][2]) + num_spikes_so_far += i1 - i0 + overlapping_corrected = overlapping.copy() + overlapping_corrected[overlapping_corrected >= 0] += num_spikes_so_far + overlapping_mask_corrected.append(overlapping_corrected) + overlapping_mask = np.concatenate(overlapping_mask_corrected) + print(f"Found {len(overlapping_mask)} overlapping spikes") + self.overlapping_mask = overlapping_mask + else: + overlapping_mask = np.concatenate(overlapping_mask) self._extension_data[f"amplitude_scalings"] = amp_scalings @@ -314,13 +337,10 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) spikes_in_segment = spikes[segment_slices[segment_index]] + # TODO: handle spikes in margin! i0 = np.searchsorted(spikes_in_segment["sample_index"], start_frame) i1 = np.searchsorted(spikes_in_segment["sample_index"], end_frame) - local_waveforms = [] - templates = [] - scalings = [] - if i0 != i1: local_spikes = spikes_in_segment[i0:i1] traces_with_margin, left, right = get_chunk_with_margin( @@ -335,13 +355,10 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) # set colliding spikes apart (if needed) if handle_collisions: - overlapping_mask = _find_overlapping_mask( + overlapping_mask = find_overlapping_mask( local_spikes, max_consecutive_collisions, delta_collision_samples, unit_inds_to_channel_indices ) overlapping_spike_indices = overlapping_mask[:, max_consecutive_collisions] - print( - f"Found {len(overlapping_spike_indices)} overlapping spikes in segment {segment_index}! 
- chunk {start_frame} - {end_frame}" - ) else: overlapping_spike_indices = np.array([], dtype=int) @@ -368,17 +385,18 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) else: local_waveform = traces_with_margin[cut_out_start:cut_out_end, sparse_indices] assert template.shape == local_waveform.shape - local_waveforms.append(local_waveform) - templates.append(template) + linregress_res = linregress(template.flatten(), local_waveform.flatten()) scalings[spike_index] = linregress_res[0] # deal with collisions if len(overlapping_spike_indices) > 0: for overlapping in overlapping_mask: + # the current spike is the one at the 'max_consecutive_collisions' position spike_index = overlapping[max_consecutive_collisions] overlapping_spikes = local_spikes[overlapping[overlapping >= 0]] - scaled_amps = _fit_collision( + current_spike_index_within_overlapping = np.where(overlapping >= 0)[0] == max_consecutive_collisions + scaled_amps = fit_collision( overlapping_spikes, traces_with_margin, start_frame, @@ -392,9 +410,12 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) cut_out_after, ) # get the right amplitude scaling - scalings[spike_index] = scaled_amps[np.where(overlapping >= 0)[0] == max_consecutive_collisions] + scalings[spike_index] = scaled_amps[current_spike_index_within_overlapping] + else: + scalings = np.array([]) + overlapping_mask = np.array([], shape=(0, max_consecutive_collisions + 1)) - return (scalings,) + return (scalings, overlapping_mask) ### Collision handling ### @@ -422,7 +443,7 @@ def _are_unit_indices_overlapping(unit_inds_to_channel_indices, i, j): return False -def _find_overlapping_mask(spikes, max_consecutive_spikes, delta_overlap_samples, unit_inds_to_channel_indices): +def find_overlapping_mask(spikes, max_consecutive_spikes, delta_overlap_samples, unit_inds_to_channel_indices): """ Finds the overlapping spikes for each spike in spikes and returns a boolean mask of shape (n_spikes, 2 * max_consecutive_spikes + 1). @@ -525,7 +546,7 @@ def _find_overlapping_mask(spikes, max_consecutive_spikes, delta_overlap_samples return overlapping_mask -def _fit_collision( +def fit_collision( overlapping_spikes, traces_with_margin, start_frame, @@ -537,8 +558,41 @@ def _fit_collision( unit_inds_to_channel_indices, cut_out_before, cut_out_after, + debug=True, ): - """ """ + """ + Compute the best fit for a collision between a spike and its overlapping spikes. + + Parameters + ---------- + overlapping_spikes: np.ndarray + A numpy array of shape (n_overlapping_spikes, ) containing the overlapping spikes (spike_dtype). + traces_with_margin: np.ndarray + A numpy array of shape (n_samples, n_channels) containing the traces with a margin. + start_frame: int + The start frame of the chunk for traces_with_margin. + end_frame: int + The end frame of the chunk for traces_with_margin. + left: int + The left margin of the chunk for traces_with_margin. + right: int + The right margin of the chunk for traces_with_margin. + nbefore: int + The number of samples before the spike to consider for the fit. + all_templates: np.ndarray + A numpy array of shape (n_units, n_samples, n_channels) containing the templates. + unit_inds_to_channel_indices: dict + A dictionary mapping unit indices to channel indices. + cut_out_before: int + The number of samples to cut out before the spike. + cut_out_after: int + The number of samples to cut out after the spike. + + Returns + ------- + np.ndarray + The fitted scaling factors for the overlapping spikes. 
+ """ from sklearn.linear_model import LinearRegression sample_first_centered = overlapping_spikes[0]["sample_index"] - start_frame - left @@ -550,6 +604,7 @@ def _fit_collision( sparse_indices_i = unit_inds_to_channel_indices[spike["unit_index"]] sparse_indices = np.union1d(sparse_indices, sparse_indices_i) + # TODO: check alignment!!! local_waveform_start = max(0, sample_first_centered - cut_out_before) local_waveform_end = min(traces_with_margin.shape[0], sample_last_centered + cut_out_after) local_waveform = traces_with_margin[local_waveform_start:local_waveform_end, sparse_indices] @@ -559,7 +614,7 @@ def _fit_collision( for i, spike in enumerate(overlapping_spikes): full_template = np.zeros_like(local_waveform) # center wrt cutout traces - sample_centered = spike["sample_index"] - local_waveform_start + sample_centered = spike["sample_index"] - start_frame - left - local_waveform_start template = all_templates[spike["unit_index"]][:, sparse_indices] template_cut = template[nbefore - cut_out_before : nbefore + cut_out_after] # deal with borders @@ -571,86 +626,136 @@ def _fit_collision( full_template[sample_centered - cut_out_before : sample_centered + cut_out_after] = template_cut X[:, i] = full_template.T.flatten() + if debug: + import matplotlib.pyplot as plt + + fig, ax = plt.subplots() + max_tr = np.max(np.abs(local_waveform)) + + _ = ax.plot(y, color="k") + + for i, spike in enumerate(overlapping_spikes): + _ = ax.plot(X[:, i], color=f"C{i}", alpha=0.5) + plt.show() + reg = LinearRegression().fit(X, y) - amps = reg.coef_ - return amps + scalings = reg.coef_ + return scalings + + +def plot_collisions(we, sparsity=None, num_collisions=None): + """ + Plot the fitting of collision spikes. + + Parameters + ---------- + we : WaveformExtractor + The WaveformExtractor object. + sparsity : ChannelSparsity, default=None + The ChannelSparsity. If None, only main channels are plotted. + num_collisions : int, default=None + Number of collisions to plot. If None, all collisions are plotted. + """ + assert we.is_extension("amplitude_scalings"), "Could not find amplitude scalings extension!" + sac = we.load_extension("amplitude_scalings") + handle_collisions = sac._params["handle_collisions"] + assert handle_collisions, "Amplitude scalings was run without handling collisions!" 
+ scalings = sac.get_data() + + overlapping_mask = sac.overlapping_mask + num_collisions = num_collisions or len(overlapping_mask) + spikes = sac.spikes + max_consecutive_collisions = sac._params["max_consecutive_collisions"] + + for i in range(num_collisions): + ax = _plot_one_collision( + we, overlapping_mask[i], spikes, scalings=scalings, max_consecutive_collisions=max_consecutive_collisions + ) + + +def _plot_one_collision( + we, + overlap, + spikes, + scalings=None, + sparsity=None, + cut_out_samples=100, + max_consecutive_collisions=3, +): + import matplotlib.pyplot as plt + + recording = we.recording + nbefore_nafter_max = max(we.nafter, we.nbefore) + cut_out_samples = max(cut_out_samples, nbefore_nafter_max) + spike_index = overlap[max_consecutive_collisions] + overlap_indices = overlap[overlap != -1] + overlapping_spikes = spikes[overlap_indices] + + if sparsity is not None: + unit_inds_to_channel_indices = sparsity.unit_id_to_channel_indices + sparse_indices = np.array([], dtype="int") + for spike in overlapping_spikes: + sparse_indices_i = unit_inds_to_channel_indices[we.unit_ids[spike["unit_index"]]] + sparse_indices = np.union1d(sparse_indices, sparse_indices_i) + else: + sparse_indices = np.unique(overlapping_spikes["channel_index"]) + + channel_ids = recording.channel_ids[sparse_indices] + + center_spike = spikes[spike_index] + max_delta = np.max( + [ + np.abs(center_spike["sample_index"] - overlapping_spikes[0]["sample_index"]), + np.abs(center_spike["sample_index"] - overlapping_spikes[-1]["sample_index"]), + ] + ) + sf = max(0, center_spike["sample_index"] - max_delta - cut_out_samples) + ef = min( + center_spike["sample_index"] + max_delta + cut_out_samples, + recording.get_num_samples(segment_index=center_spike["segment_index"]), + ) + tr_overlap = recording.get_traces(start_frame=sf, end_frame=ef, channel_ids=channel_ids, return_scaled=True) + ts = np.arange(sf, ef) / recording.sampling_frequency * 1000 + max_tr = np.max(np.abs(tr_overlap)) + fig, ax = plt.subplots() + for ch, tr in enumerate(tr_overlap.T): + _ = ax.plot(ts, tr + 1.2 * ch * max_tr, color="k") + ax.text(ts[0], 1.2 * ch * max_tr - 0.3 * max_tr, f"Ch:{channel_ids[ch]}") + + used_labels = [] + for spike in overlapping_spikes: + label = f"U{spike['unit_index']}" + if label in used_labels: + label = None + else: + used_labels.append(label) + ax.axvline( + spike["sample_index"] / recording.sampling_frequency * 1000, color=f"C{spike['unit_index']}", label=label + ) + + if scalings is not None: + fitted_traces = np.zeros_like(tr_overlap) + + all_templates = we.get_all_templates() + for i, spike in enumerate(overlapping_spikes): + template = all_templates[spike["unit_index"]] + template_scaled = scalings[overlap_indices[i]] * template + template_scaled_sparse = template_scaled[:, sparse_indices] + sample_start = spike["sample_index"] - we.nbefore + sample_end = sample_start + template_scaled_sparse.shape[0] + + fitted_traces[sample_start - sf : sample_end - sf] += template_scaled_sparse + + for ch, temp in enumerate(template_scaled_sparse.T): + ts_template = np.arange(sample_start, sample_end) / recording.sampling_frequency * 1000 + _ = ax.plot(ts_template, temp + 1.2 * ch * max_tr, color=f"C{spike['unit_index']}", ls="--") + + for ch, tr in enumerate(fitted_traces.T): + _ = ax.plot(ts, tr + 1.2 * ch * max_tr, color="gray", alpha=0.7) + fitted_line = ax.get_lines()[-1] + fitted_line.set_label("Fitted") -# TODO: fix this! 
-# def plot_overlapping_spikes(we, overlap, -# spikes, cut_out_samples=100, -# max_consecutive_spikes=3, -# sparsity=None, -# fitted_amps=None): -# recording = we.recording -# nbefore_nafter_max = max(we.nafter, we.nbefore) -# cut_out_samples = max(cut_out_samples, nbefore_nafter_max) -# spike_index = overlap[max_consecutive_spikes] -# overlap_indices = overlap[overlap != -1] -# overlapping_spikes = spikes[overlap_indices] - -# if sparsity is not None: -# unit_inds_to_channel_indices = sparsity.unit_id_to_channel_indices -# sparse_indices = np.array([], dtype="int") -# for spike in overlapping_spikes: -# sparse_indices_i = unit_inds_to_channel_indices[we.unit_ids[spike["unit_index"]]] -# sparse_indices = np.union1d(sparse_indices, sparse_indices_i) -# else: -# sparse_indices = np.unique(overlapping_spikes["channel_index"]) - -# channel_ids = recording.channel_ids[sparse_indices] - -# center_spike = spikes[spike_index]["sample_index"] -# max_delta = np.max([np.abs(center_spike - overlapping_spikes[0]["sample_index"]), -# np.abs(center_spike - overlapping_spikes[-1]["sample_index"])]) -# sf = center_spike - max_delta - cut_out_samples -# ef = center_spike + max_delta + cut_out_samples -# tr_overlap = recording.get_traces(start_frame=sf, -# end_frame=ef, -# channel_ids=channel_ids, return_scaled=True) -# ts = np.arange(sf, ef) / recording.sampling_frequency * 1000 -# max_tr = np.max(np.abs(tr_overlap)) -# fig, ax = plt.subplots() -# for ch, tr in enumerate(tr_overlap.T): -# _ = ax.plot(ts, tr + 1.2 * ch * max_tr, color="k") -# ax.text(ts[0], 1.2 * ch * max_tr - 0.3 * max_tr, f"Ch:{channel_ids[ch]}") - -# used_labels = [] -# for spike in overlapping_spikes: -# label = f"U{spike['unit_index']}" -# if label in used_labels: -# label = None -# else: -# used_labels.append(label) -# ax.axvline(spike["sample_index"] / recording.sampling_frequency * 1000, -# color=f"C{spike['unit_index']}", label=label) - -# if fitted_amps is not None: -# fitted_traces = np.zeros_like(tr_overlap) - -# all_templates = we.get_all_templates() -# for i, spike in enumerate(overlapping_spikes): -# template = all_templates[spike["unit_index"]] -# template_scaled = fitted_amps[overlap_indices[i]] * template -# template_scaled_sparse = template_scaled[:, sparse_indices] -# sample_start = spike["sample_index"] - we.nbefore -# sample_end = sample_start + template_scaled_sparse.shape[0] - -# fitted_traces[sample_start - sf: sample_end - sf] += template_scaled_sparse - -# for ch, temp in enumerate(template_scaled_sparse.T): - -# ts_template = np.arange(sample_start, sample_end) / recording.sampling_frequency * 1000 -# _ = ax.plot(ts_template, temp + 1.2 * ch * max_tr, color=f"C{spike['unit_index']}", -# ls="--") - -# for ch, tr in enumerate(fitted_traces.T): -# _ = ax.plot(ts, tr + 1.2 * ch * max_tr, color="gray", alpha=0.7) - -# fitted_line = ax.get_lines()[-1] -# fitted_line.set_label("Fitted") - - -# ax.legend() -# ax.set_title(f"Spike {spike_index} - sample {center_spike}") -# return tr_overlap, ax + ax.legend() + ax.set_title(f"Spike {spike_index} - sample {center_spike['sample_index']}") + return ax From 7bec9df5c0298c853f552989ef5e3febcf0f9470 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 28 Aug 2023 17:52:07 +0200 Subject: [PATCH 12/73] Simplify and cleanup --- .../postprocessing/amplitude_scalings.py | 484 ++++++++---------- 1 file changed, 206 insertions(+), 278 deletions(-) diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 
d367ef4f22..1f7923eb05 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -7,6 +7,9 @@ from spikeinterface.core.waveform_extractor import WaveformExtractor, BaseWaveformExtractorExtension +# DEBUG = True + + class AmplitudeScalingsCalculator(BaseWaveformExtractorExtension): """ Computes amplitude scalings from WaveformExtractor. @@ -21,7 +24,6 @@ def __init__(self, waveform_extractor): self.spikes = self.waveform_extractor.sorting.to_spike_vector( extremum_channel_inds=extremum_channel_inds, use_cache=False ) - self.overlapping_mask = None def _set_params( self, @@ -30,7 +32,6 @@ def _set_params( ms_before, ms_after, handle_collisions, - max_consecutive_collisions, delta_collision_ms, ): params = dict( @@ -39,7 +40,6 @@ def _set_params( ms_before=ms_before, ms_after=ms_after, handle_collisions=handle_collisions, - max_consecutive_collisions=max_consecutive_collisions, delta_collision_ms=delta_collision_ms, ) return params @@ -63,7 +63,6 @@ def _run(self, **job_kwargs): # collisions handle_collisions = self._params["handle_collisions"] - max_consecutive_collisions = self._params["max_consecutive_collisions"] delta_collision_ms = self._params["delta_collision_ms"] delta_collision_samples = int(delta_collision_ms / 1000 * we.sampling_frequency) @@ -120,7 +119,6 @@ def _run(self, **job_kwargs): cut_out_after, return_scaled, handle_collisions, - max_consecutive_collisions, delta_collision_samples, ) processor = ChunkRecordingExecutor( @@ -133,32 +131,16 @@ def _run(self, **job_kwargs): **job_kwargs, ) out = processor.run() - (amp_scalings, overlapping_mask) = zip(*out) + (amp_scalings, collisions) = zip(*out) amp_scalings = np.concatenate(amp_scalings) - if handle_collisions > 0: - from ..core.job_tools import divide_recording_into_chunks - - overlapping_mask_corrected = [] - all_chunks = divide_recording_into_chunks(processor.recording, processor.chunk_size) - num_spikes_so_far = 0 - for i, overlapping in enumerate(overlapping_mask): - if i == 0: - continue - segment_index = all_chunks[i - 1][0] - spikes_in_segment = self.spikes[segment_slices[segment_index]] - i0 = np.searchsorted(spikes_in_segment["sample_index"], all_chunks[i - 1][1]) - i1 = np.searchsorted(spikes_in_segment["sample_index"], all_chunks[i - 1][2]) - num_spikes_so_far += i1 - i0 - overlapping_corrected = overlapping.copy() - overlapping_corrected[overlapping_corrected >= 0] += num_spikes_so_far - overlapping_mask_corrected.append(overlapping_corrected) - overlapping_mask = np.concatenate(overlapping_mask_corrected) - print(f"Found {len(overlapping_mask)} overlapping spikes") - self.overlapping_mask = overlapping_mask - else: - overlapping_mask = np.concatenate(overlapping_mask) + + collisions_dict = {} + if handle_collisions: + for collision in collisions: + collisions_dict.update(collision) self._extension_data[f"amplitude_scalings"] = amp_scalings + self._extension_data[f"collisions"] = collisions_dict def get_data(self, outputs="concatenated"): """ @@ -206,7 +188,6 @@ def compute_amplitude_scalings( ms_before=None, ms_after=None, handle_collisions=False, - max_consecutive_collisions=3, delta_collision_ms=2, load_if_exists=False, outputs="concatenated", @@ -235,10 +216,8 @@ def compute_amplitude_scalings( Whether to handle collisions between spikes. 
If True, the amplitude scaling of colliding spikes (defined as spikes within `delta_collision_ms` ms and with overlapping sparsity) is computed by fitting a multi-linear regression model (with `sklearn.LinearRegression`). If False, each spike is fitted independently. - max_consecutive_collisions: int, default: 3 - The maximum number of consecutive collisions to handle on each side of a spike. delta_collision_ms: float, default: 2 - The maximum time difference in ms between two spikes to be considered as colliding. + The maximum time difference in ms before and after a spike to gather colliding spikes. load_if_exists : bool, default: False Whether to load precomputed spike amplitudes, if they already exist. outputs: str, default: 'concatenated' @@ -264,7 +243,6 @@ def compute_amplitude_scalings( ms_before=ms_before, ms_after=ms_after, handle_collisions=handle_collisions, - max_consecutive_collisions=max_consecutive_collisions, delta_collision_ms=delta_collision_ms, ) sac.run(**job_kwargs) @@ -288,7 +266,6 @@ def _init_worker_amplitude_scalings( cut_out_after, return_scaled, handle_collisions, - max_consecutive_collisions, delta_collision_samples, ): # create a local dict per worker @@ -304,15 +281,15 @@ def _init_worker_amplitude_scalings( worker_ctx["return_scaled"] = return_scaled worker_ctx["unit_inds_to_channel_indices"] = unit_inds_to_channel_indices worker_ctx["handle_collisions"] = handle_collisions - worker_ctx["max_consecutive_collisions"] = max_consecutive_collisions worker_ctx["delta_collision_samples"] = delta_collision_samples if not handle_collisions: worker_ctx["margin"] = max(nbefore, nafter) else: + # in this case we extend the margin to be able to get with collisions outside the chunk margin_waveforms = max(nbefore, nafter) - max_margin_collisions = int(max_consecutive_collisions * delta_collision_samples) - worker_ctx["margin"] = max(margin_waveforms, max_margin_collisions) + max_margin_collisions = delta_collision_samples + margin_waveforms + worker_ctx["margin"] = max_margin_collisions return worker_ctx @@ -332,7 +309,6 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) margin = worker_ctx["margin"] return_scaled = worker_ctx["return_scaled"] handle_collisions = worker_ctx["handle_collisions"] - max_consecutive_collisions = worker_ctx["max_consecutive_collisions"] delta_collision_samples = worker_ctx["delta_collision_samples"] spikes_in_segment = spikes[segment_slices[segment_index]] @@ -355,17 +331,21 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) # set colliding spikes apart (if needed) if handle_collisions: - overlapping_mask = find_overlapping_mask( - local_spikes, max_consecutive_collisions, delta_collision_samples, unit_inds_to_channel_indices + # local spikes with margin! 
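# The two np.searchsorted calls below bracket, inside this segment's spike vector,
# the spikes whose sample_index falls in [start_frame - left, end_frame + right):
# i0_margin is the first spike at or after the left edge and i1_margin the first
# spike at or after the right edge, so spikes_in_segment[i0_margin:i1_margin]
# holds the chunk's spikes plus the enlarged collision margin on both sides
# (for example, np.searchsorted([10, 20, 30], 25) == 2).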
+ i0_margin = np.searchsorted(spikes_in_segment["sample_index"], start_frame - left) + i1_margin = np.searchsorted(spikes_in_segment["sample_index"], end_frame + right) + local_spikes_w_margin = spikes_in_segment[i0_margin:i1_margin] + collisions = find_collisions( + local_spikes, local_spikes_w_margin, delta_collision_samples, unit_inds_to_channel_indices ) - overlapping_spike_indices = overlapping_mask[:, max_consecutive_collisions] else: - overlapping_spike_indices = np.array([], dtype=int) + collisions = {} - # get all waveforms + # compute the scaling for each spike scalings = np.zeros(len(local_spikes), dtype=float) + collisions_dict = {} for spike_index, spike in enumerate(local_spikes): - if spike_index in overlapping_spike_indices: + if spike_index in collisions.keys(): # we deal with overlapping spikes later continue unit_index = spike["unit_index"] @@ -390,14 +370,13 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) scalings[spike_index] = linregress_res[0] # deal with collisions - if len(overlapping_spike_indices) > 0: - for overlapping in overlapping_mask: - # the current spike is the one at the 'max_consecutive_collisions' position - spike_index = overlapping[max_consecutive_collisions] - overlapping_spikes = local_spikes[overlapping[overlapping >= 0]] - current_spike_index_within_overlapping = np.where(overlapping >= 0)[0] == max_consecutive_collisions + if len(collisions) > 0: + num_spikes_in_previous_segments = int( + np.sum([len(spikes[segment_slices[s]]) for s in range(segment_index)]) + ) + for spike_index, collision in collisions.items(): scaled_amps = fit_collision( - overlapping_spikes, + collision, traces_with_margin, start_frame, end_frame, @@ -409,13 +388,16 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) cut_out_before, cut_out_after, ) - # get the right amplitude scaling - scalings[spike_index] = scaled_amps[current_spike_index_within_overlapping] + # the scaling for the current spike is at index 0 + scalings[spike_index] = scaled_amps[0] + + # make collision_dict indices "absolute" by adding i0 and the cumulative number of spikes in previous segments + collisions_dict.update({spike_index + i0 + num_spikes_in_previous_segments: collision}) else: scalings = np.array([]) - overlapping_mask = np.array([], shape=(0, max_consecutive_collisions + 1)) + collisions_dict = {} - return (scalings, overlapping_mask) + return (scalings, collisions_dict) ### Collision handling ### @@ -443,111 +425,65 @@ def _are_unit_indices_overlapping(unit_inds_to_channel_indices, i, j): return False -def find_overlapping_mask(spikes, max_consecutive_spikes, delta_overlap_samples, unit_inds_to_channel_indices): +def find_collisions(spikes, spikes_w_margin, delta_collision_samples, unit_inds_to_channel_indices): """ - Finds the overlapping spikes for each spike in spikes and returns a boolean mask of shape - (n_spikes, 2 * max_consecutive_spikes + 1). + Finds the collisions between spikes. 
Parameters ---------- spikes: np.array An array of spikes - max_consecutive_spikes: int - The maximum number of consecutive spikes to consider - delta_overlap_samples: int + spikes_w_margin: np.array + An array of spikes within the added margin + delta_collision_samples: int The maximum number of samples between two spikes to consider them as overlapping unit_inds_to_channel_indices: dict A dictionary mapping unit indices to channel indices Returns ------- - overlapping_mask: np.array - A boolean mask of shape (n_spikes, 2 * max_consecutive_spikes + 1) where the central column (max_consecutive_spikes) - is the current spike index, while the other columns are the indices of the overlapping spikes. The first - max_consecutive_spikes columns are the pre-overlapping spikes, while the last max_consecutive_spikes columns are - the post-overlapping spikes. + collision_spikes_dict: np.array + A dictionary with collisions. The key is the index of the spike with collision, the value is an + array of overlapping spikes, including the spike itself at position 0. """ + collision_spikes_dict = {} + for spike_index, spike in enumerate(spikes): + # find the index of the spike within the spikes_w_margin + spike_index_w_margin = np.where(spikes_w_margin == spike)[0][0] - # overlapping_mask is a matrix of shape (n_spikes, 2 * max_consecutive_spikes + 1) - # the central column (max_consecutive_spikes) is the current spike index, while the other columns are the - # indices of the overlapping spikes. The first max_consecutive_spikes columns are the pre-overlapping spikes, - # while the last max_consecutive_spikes columns are the post-overlapping spikes - # Rows with all -1 are non-colliding spikes and are removed later - overlapping_mask_full = -1 * np.ones((len(spikes), 2 * max_consecutive_spikes + 1), dtype=int) - overlapping_mask_full[:, max_consecutive_spikes] = np.arange(len(spikes)) - - for i, spike in enumerate(spikes): - # find the possible spikes per and post within max_consecutive_spikes * delta_overlap_samples + # find the possible spikes per and post within delta_collision_samples consecutive_window_pre = np.searchsorted( - spikes["sample_index"], - spike["sample_index"] - max_consecutive_spikes * delta_overlap_samples, + spikes_w_margin["sample_index"], + spike["sample_index"] - delta_collision_samples, ) consecutive_window_post = np.searchsorted( - spikes["sample_index"], - spike["sample_index"] + max_consecutive_spikes * delta_overlap_samples, + spikes_w_margin["sample_index"], + spike["sample_index"] + delta_collision_samples, + ) + # exclude the spike itself (it is included in the collision_spikes by construction) + pre_possible_consecutive_spike_indices = np.arange(consecutive_window_pre, spike_index_w_margin) + post_possible_consecutive_spike_indices = np.arange(spike_index_w_margin + 1, consecutive_window_post) + possible_overlapping_spike_indices = np.concatenate( + (pre_possible_consecutive_spike_indices, post_possible_consecutive_spike_indices) ) - pre_possible_consecutive_spikes = spikes[consecutive_window_pre:i][::-1] - post_possible_consecutive_spikes = spikes[i + 1 : consecutive_window_post] - - # here we fill in the overlapping information by consecutively looping through the possible consecutive spikes - # and checking the spatial overlap and the delay with the previous overlapping spike - # pre and post are hanlded separately. 
Note that the pre-spikes are already sorted backwards - - # overlap_rank keeps track of the rank of consecutive collisions (i.e., rank 0 is the first, rank 1 is the second, etc.) - # this is needed because we are just considering spikes with spatial overlap, while the possible consecutive spikes - # only looked at the temporal overlap - overlap_rank = 0 - if len(pre_possible_consecutive_spikes) > 0: - for c_pre, spike_consecutive_pre in enumerate(pre_possible_consecutive_spikes[::-1]): - if _are_unit_indices_overlapping( - unit_inds_to_channel_indices, spike["unit_index"], spike_consecutive_pre["unit_index"] - ): - if ( - spikes[overlapping_mask_full[i, max_consecutive_spikes - overlap_rank]]["sample_index"] - - spike_consecutive_pre["sample_index"] - < delta_overlap_samples - ): - overlapping_mask_full[i, max_consecutive_spikes - overlap_rank - 1] = i - 1 - c_pre - overlap_rank += 1 - else: - break - # if overlap_rank > 1: - # print(f"\tHigher order pre-overlap for spike {i}!") - - overlap_rank = 0 - if len(post_possible_consecutive_spikes) > 0: - for c_post, spike_consecutive_post in enumerate(post_possible_consecutive_spikes): - if _are_unit_indices_overlapping( - unit_inds_to_channel_indices, spike["unit_index"], spike_consecutive_post["unit_index"] - ): - if ( - spike_consecutive_post["sample_index"] - - spikes[overlapping_mask_full[i, max_consecutive_spikes + overlap_rank]]["sample_index"] - < delta_overlap_samples - ): - overlapping_mask_full[i, max_consecutive_spikes + overlap_rank + 1] = i + 1 + c_post - overlap_rank += 1 - else: - break - # if overlap_rank > 1: - # print(f"\tHigher order post-overlap for spike {i}!") - - # in case no collisions were found, we set the central column to -1 so that we can easily identify the non-colliding spikes - if np.sum(overlapping_mask_full[i] != -1) == 1: - overlapping_mask_full[i, max_consecutive_spikes] = -1 - - # only return rows with collisions - overlapping_inds = [] - for i, overlapping in enumerate(overlapping_mask_full): - if np.any(overlapping >= 0): - overlapping_inds.append(i) - overlapping_mask = overlapping_mask_full[overlapping_inds] - - return overlapping_mask + + # find the overlapping spikes in space as well + for possible_overlapping_spike_index in possible_overlapping_spike_indices: + if _are_unit_indices_overlapping( + unit_inds_to_channel_indices, + spike["unit_index"], + spikes_w_margin[possible_overlapping_spike_index]["unit_index"], + ): + if spike_index not in collision_spikes_dict: + collision_spikes_dict[spike_index] = np.array([spike]) + collision_spikes_dict[spike_index] = np.concatenate( + (collision_spikes_dict[spike_index], [spikes_w_margin[possible_overlapping_spike_index]]) + ) + return collision_spikes_dict def fit_collision( - overlapping_spikes, + collision, traces_with_margin, start_frame, end_frame, @@ -558,15 +494,16 @@ def fit_collision( unit_inds_to_channel_indices, cut_out_before, cut_out_after, - debug=True, ): """ Compute the best fit for a collision between a spike and its overlapping spikes. + The function first cuts out the traces around the spike and its overlapping spikes, then + fits a multi-linear regression model to the traces using the centered templates as predictors. Parameters ---------- - overlapping_spikes: np.ndarray - A numpy array of shape (n_overlapping_spikes, ) containing the overlapping spikes (spike_dtype). + collision: np.ndarray + A numpy array of shape (n_colliding_spikes, ) containing the colliding spikes (spike_dtype). 
traces_with_margin: np.ndarray A numpy array of shape (n_samples, n_channels) containing the traces with a margin. start_frame: int @@ -591,30 +528,30 @@ def fit_collision( Returns ------- np.ndarray - The fitted scaling factors for the overlapping spikes. + The fitted scaling factors for the colliding spikes. """ from sklearn.linear_model import LinearRegression - sample_first_centered = overlapping_spikes[0]["sample_index"] - start_frame - left - sample_last_centered = overlapping_spikes[-1]["sample_index"] - start_frame - left + # make center of the spike externally + sample_first_centered = np.min(collision["sample_index"]) - (start_frame - left) + sample_last_centered = np.max(collision["sample_index"]) - (start_frame - left) # construct sparsity as union between units' sparsity sparse_indices = np.array([], dtype="int") - for spike in overlapping_spikes: + for spike in collision: sparse_indices_i = unit_inds_to_channel_indices[spike["unit_index"]] sparse_indices = np.union1d(sparse_indices, sparse_indices_i) - # TODO: check alignment!!! local_waveform_start = max(0, sample_first_centered - cut_out_before) local_waveform_end = min(traces_with_margin.shape[0], sample_last_centered + cut_out_after) local_waveform = traces_with_margin[local_waveform_start:local_waveform_end, sparse_indices] y = local_waveform.T.flatten() - X = np.zeros((len(y), len(overlapping_spikes))) - for i, spike in enumerate(overlapping_spikes): + X = np.zeros((len(y), len(collision))) + for i, spike in enumerate(collision): full_template = np.zeros_like(local_waveform) # center wrt cutout traces - sample_centered = spike["sample_index"] - start_frame - left - local_waveform_start + sample_centered = spike["sample_index"] - (start_frame - left) - local_waveform_start template = all_templates[spike["unit_index"]][:, sparse_indices] template_cut = template[nbefore - cut_out_before : nbefore + cut_out_after] # deal with borders @@ -626,136 +563,127 @@ def fit_collision( full_template[sample_centered - cut_out_before : sample_centered + cut_out_after] = template_cut X[:, i] = full_template.T.flatten() - if debug: - import matplotlib.pyplot as plt - - fig, ax = plt.subplots() - max_tr = np.max(np.abs(local_waveform)) - - _ = ax.plot(y, color="k") - - for i, spike in enumerate(overlapping_spikes): - _ = ax.plot(X[:, i], color=f"C{i}", alpha=0.5) - plt.show() - reg = LinearRegression().fit(X, y) scalings = reg.coef_ return scalings -def plot_collisions(we, sparsity=None, num_collisions=None): - """ - Plot the fitting of collision spikes. - - Parameters - ---------- - we : WaveformExtractor - The WaveformExtractor object. - sparsity : ChannelSparsity, default=None - The ChannelSparsity. If None, only main channels are plotted. - num_collisions : int, default=None - Number of collisions to plot. If None, all collisions are plotted. - """ - assert we.is_extension("amplitude_scalings"), "Could not find amplitude scalings extension!" - sac = we.load_extension("amplitude_scalings") - handle_collisions = sac._params["handle_collisions"] - assert handle_collisions, "Amplitude scalings was run without handling collisions!" 
- scalings = sac.get_data() - - overlapping_mask = sac.overlapping_mask - num_collisions = num_collisions or len(overlapping_mask) - spikes = sac.spikes - max_consecutive_collisions = sac._params["max_consecutive_collisions"] - - for i in range(num_collisions): - ax = _plot_one_collision( - we, overlapping_mask[i], spikes, scalings=scalings, max_consecutive_collisions=max_consecutive_collisions - ) - - -def _plot_one_collision( - we, - overlap, - spikes, - scalings=None, - sparsity=None, - cut_out_samples=100, - max_consecutive_collisions=3, -): - import matplotlib.pyplot as plt - - recording = we.recording - nbefore_nafter_max = max(we.nafter, we.nbefore) - cut_out_samples = max(cut_out_samples, nbefore_nafter_max) - spike_index = overlap[max_consecutive_collisions] - overlap_indices = overlap[overlap != -1] - overlapping_spikes = spikes[overlap_indices] - - if sparsity is not None: - unit_inds_to_channel_indices = sparsity.unit_id_to_channel_indices - sparse_indices = np.array([], dtype="int") - for spike in overlapping_spikes: - sparse_indices_i = unit_inds_to_channel_indices[we.unit_ids[spike["unit_index"]]] - sparse_indices = np.union1d(sparse_indices, sparse_indices_i) - else: - sparse_indices = np.unique(overlapping_spikes["channel_index"]) - - channel_ids = recording.channel_ids[sparse_indices] - - center_spike = spikes[spike_index] - max_delta = np.max( - [ - np.abs(center_spike["sample_index"] - overlapping_spikes[0]["sample_index"]), - np.abs(center_spike["sample_index"] - overlapping_spikes[-1]["sample_index"]), - ] - ) - sf = max(0, center_spike["sample_index"] - max_delta - cut_out_samples) - ef = min( - center_spike["sample_index"] + max_delta + cut_out_samples, - recording.get_num_samples(segment_index=center_spike["segment_index"]), - ) - tr_overlap = recording.get_traces(start_frame=sf, end_frame=ef, channel_ids=channel_ids, return_scaled=True) - ts = np.arange(sf, ef) / recording.sampling_frequency * 1000 - max_tr = np.max(np.abs(tr_overlap)) - fig, ax = plt.subplots() - for ch, tr in enumerate(tr_overlap.T): - _ = ax.plot(ts, tr + 1.2 * ch * max_tr, color="k") - ax.text(ts[0], 1.2 * ch * max_tr - 0.3 * max_tr, f"Ch:{channel_ids[ch]}") - - used_labels = [] - for spike in overlapping_spikes: - label = f"U{spike['unit_index']}" - if label in used_labels: - label = None - else: - used_labels.append(label) - ax.axvline( - spike["sample_index"] / recording.sampling_frequency * 1000, color=f"C{spike['unit_index']}", label=label - ) - - if scalings is not None: - fitted_traces = np.zeros_like(tr_overlap) - - all_templates = we.get_all_templates() - for i, spike in enumerate(overlapping_spikes): - template = all_templates[spike["unit_index"]] - template_scaled = scalings[overlap_indices[i]] * template - template_scaled_sparse = template_scaled[:, sparse_indices] - sample_start = spike["sample_index"] - we.nbefore - sample_end = sample_start + template_scaled_sparse.shape[0] - - fitted_traces[sample_start - sf : sample_end - sf] += template_scaled_sparse - - for ch, temp in enumerate(template_scaled_sparse.T): - ts_template = np.arange(sample_start, sample_end) / recording.sampling_frequency * 1000 - _ = ax.plot(ts_template, temp + 1.2 * ch * max_tr, color=f"C{spike['unit_index']}", ls="--") - - for ch, tr in enumerate(fitted_traces.T): - _ = ax.plot(ts, tr + 1.2 * ch * max_tr, color="gray", alpha=0.7) - - fitted_line = ax.get_lines()[-1] - fitted_line.set_label("Fitted") - - ax.legend() - ax.set_title(f"Spike {spike_index} - sample {center_spike['sample_index']}") - return ax +# 
uncomment for debugging +# def plot_collisions(we, sparsity=None, num_collisions=None): +# """ +# Plot the fitting of collision spikes. + +# Parameters +# ---------- +# we : WaveformExtractor +# The WaveformExtractor object. +# sparsity : ChannelSparsity, default=None +# The ChannelSparsity. If None, only main channels are plotted. +# num_collisions : int, default=None +# Number of collisions to plot. If None, all collisions are plotted. +# """ +# assert we.is_extension("amplitude_scalings"), "Could not find amplitude scalings extension!" +# sac = we.load_extension("amplitude_scalings") +# handle_collisions = sac._params["handle_collisions"] +# assert handle_collisions, "Amplitude scalings was run without handling collisions!" +# scalings = sac.get_data() + +# # overlapping_mask = sac.overlapping_mask +# # num_collisions = num_collisions or len(overlapping_mask) +# spikes = sac.spikes +# collisions = sac._extension_data[f"collisions"] +# collision_keys = list(collisions.keys()) +# num_collisions = num_collisions or len(collisions) +# num_collisions = min(num_collisions, len(collisions)) + +# for i in range(num_collisions): +# overlapping_spikes = collisions[collision_keys[i]] +# ax = _plot_one_collision( +# we, collision_keys[i], overlapping_spikes, spikes, scalings=scalings, sparsity=sparsity +# ) + + +# def _plot_one_collision( +# we, +# spike_index, +# overlapping_spikes, +# spikes, +# scalings=None, +# sparsity=None, +# cut_out_samples=100, +# ): +# import matplotlib.pyplot as plt + +# recording = we.recording +# nbefore_nafter_max = max(we.nafter, we.nbefore) +# cut_out_samples = max(cut_out_samples, nbefore_nafter_max) + +# if sparsity is not None: +# unit_inds_to_channel_indices = sparsity.unit_id_to_channel_indices +# sparse_indices = np.array([], dtype="int") +# for spike in overlapping_spikes: +# sparse_indices_i = unit_inds_to_channel_indices[we.unit_ids[spike["unit_index"]]] +# sparse_indices = np.union1d(sparse_indices, sparse_indices_i) +# else: +# sparse_indices = np.unique(overlapping_spikes["channel_index"]) + +# channel_ids = recording.channel_ids[sparse_indices] + +# center_spike = overlapping_spikes[0] +# max_delta = np.max( +# [ +# np.abs(center_spike["sample_index"] - np.min(overlapping_spikes[1:]["sample_index"])), +# np.abs(center_spike["sample_index"] - np.max(overlapping_spikes[1:]["sample_index"])), +# ] +# ) +# sf = max(0, center_spike["sample_index"] - max_delta - cut_out_samples) +# ef = min( +# center_spike["sample_index"] + max_delta + cut_out_samples, +# recording.get_num_samples(segment_index=center_spike["segment_index"]), +# ) +# tr_overlap = recording.get_traces(start_frame=sf, end_frame=ef, channel_ids=channel_ids, return_scaled=True) +# ts = np.arange(sf, ef) / recording.sampling_frequency * 1000 +# max_tr = np.max(np.abs(tr_overlap)) +# fig, ax = plt.subplots() +# for ch, tr in enumerate(tr_overlap.T): +# _ = ax.plot(ts, tr + 1.2 * ch * max_tr, color="k") +# ax.text(ts[0], 1.2 * ch * max_tr - 0.3 * max_tr, f"Ch:{channel_ids[ch]}") + +# used_labels = [] +# for i, spike in enumerate(overlapping_spikes): +# label = f"U{spike['unit_index']}" +# if label in used_labels: +# label = None +# else: +# used_labels.append(label) +# ax.axvline( +# spike["sample_index"] / recording.sampling_frequency * 1000, color=f"C{spike['unit_index']}", label=label +# ) + +# if scalings is not None: +# fitted_traces = np.zeros_like(tr_overlap) + +# all_templates = we.get_all_templates() +# for i, spike in enumerate(overlapping_spikes): +# template = 
all_templates[spike["unit_index"]] +# overlap_index = np.where(spikes == spike)[0][0] +# template_scaled = scalings[overlap_index] * template +# template_scaled_sparse = template_scaled[:, sparse_indices] +# sample_start = spike["sample_index"] - we.nbefore +# sample_end = sample_start + template_scaled_sparse.shape[0] + +# fitted_traces[sample_start - sf : sample_end - sf] += template_scaled_sparse + +# for ch, temp in enumerate(template_scaled_sparse.T): +# ts_template = np.arange(sample_start, sample_end) / recording.sampling_frequency * 1000 +# _ = ax.plot(ts_template, temp + 1.2 * ch * max_tr, color=f"C{spike['unit_index']}", ls="--") + +# for ch, tr in enumerate(fitted_traces.T): +# _ = ax.plot(ts, tr + 1.2 * ch * max_tr, color="gray", alpha=0.7) + +# fitted_line = ax.get_lines()[-1] +# fitted_line.set_label("Fitted") + +# ax.legend() +# ax.set_title(f"Spike {spike_index} - sample {center_spike['sample_index']}") +# return ax From 2a0e042bdaa186836377d02d01ca83c7d06b71d3 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 29 Aug 2023 09:24:43 +0200 Subject: [PATCH 13/73] Implement SpikeRetriever. --- src/spikeinterface/core/node_pipeline.py | 102 ++++++++- .../core/tests/test_node_pipeline.py | 199 ++++++++++-------- 2 files changed, 205 insertions(+), 96 deletions(-) diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index 9ea5ad59e7..ff747fe2a0 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -84,7 +84,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin, *ar raise NotImplementedError -# nodes graph must have either a PeakSource (PeakDetector or PeakRetriever or SpikeRetriever) +# nodes graph must have a PeakSource (PeakDetector or PeakRetriever or SpikeRetriever) # as first element they play the same role in pipeline : give some peaks (and eventually more) @@ -138,7 +138,103 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): # this is not implemented yet this will be done in separted PR class SpikeRetriever(PeakSource): - pass + """ + This class is usefull to inject a sorting object in the node pipepline mechanisim. + It allows to compute some post processing with the same machinery used for sorting components. + This is a first step to totaly refactor: + * compute_spike_locations() + * compute_amplitude_scalings() + * compute_spike_amplitudes() + * compute_principal_components() + + + recording: + + sorting: + + channel_from_template: bool (default True) + If True then the channel_index is infered from template and extremum_channel_inds must be provided. + If False every spikes compute its own channel index given a radius around the template max channel. + extremum_channel_inds: dict of int + The extremum channel index dict given from template. + radius_um: float (default 50.) + The radius to find the real max channel. + Used only when channel_from_template=False + peak_sign: str (default "neg") + Peak sign to find the max channel. 
+ Used only when channel_from_template=False + """ + def __init__(self, recording, sorting, + channel_from_template=True, + extremum_channel_inds=None, + radius_um=50, + peak_sign="neg" + ): + PipelineNode.__init__(self, recording, return_output=False) + + self.channel_from_template = channel_from_template + + assert extremum_channel_inds is not None, "SpikeRetriever need the dict extremum_channel_inds" + + self.peaks = sorting_to_peak(sorting, extremum_channel_inds) + + if not channel_from_template: + channel_distance = get_channel_distances(recording) + self.neighbours_mask = channel_distance < radius_um + self.peak_sign = peak_sign + + + # precompute segment slice + self.segment_slices = [] + for segment_index in range(recording.get_num_segments()): + i0 = np.searchsorted(self.peaks["segment_index"], segment_index) + i1 = np.searchsorted(self.peaks["segment_index"], segment_index + 1) + self.segment_slices.append(slice(i0, i1)) + + def get_trace_margin(self): + return 0 + + def get_dtype(self): + return base_peak_dtype + + def compute(self, traces, start_frame, end_frame, segment_index, max_margin): + # get local peaks + sl = self.segment_slices[segment_index] + peaks_in_segment = self.peaks[sl] + i0 = np.searchsorted(peaks_in_segment["sample_index"], start_frame) + i1 = np.searchsorted(peaks_in_segment["sample_index"], end_frame) + local_peaks = peaks_in_segment[i0:i1] + + # make sample index local to traces + local_peaks = local_peaks.copy() + local_peaks["sample_index"] -= start_frame - max_margin + + if not self.channel_from_template: + # handle channel spike per spike + for i, peak in enumerate(local_peaks): + chans = np.flatnonzero(self.neighbours_mask[peak["channel_index"]]) + sparse_wfs = traces[peak["sample_index"], chans] + if self.peak_sign == "neg": + local_peaks[i]["channel_index"] = chans[np.argmin(sparse_wfs)] + elif self.peak_sign == "pos": + local_peaks[i]["channel_index"] = chans[np.argmax(sparse_wfs)] + elif self.peak_sign == "both": + local_peaks[i]["channel_index"] = chans[np.argmax(np.abs(sparse_wfs))] + + # TODO: "amplitude" ??? + + return (local_peaks,) + + +def sorting_to_peak(sorting, extremum_channel_inds): + spikes = sorting.to_spike_vector() + peaks = np.zeros(spikes.size, dtype=base_peak_dtype) + peaks["sample_index"] = spikes["sample_index"] + extremum_channel_inds_ = np.array([extremum_channel_inds[unit_id] for unit_id in sorting.unit_ids]) + peaks["channel_index"] = extremum_channel_inds_[spikes["unit_index"]] + peaks["amplitude"] = 0.0 + peaks["segment_index"] = spikes["segment_index"] + return peaks class WaveformsNode(PipelineNode): @@ -423,7 +519,7 @@ def _compute_peak_pipeline_chunk(segment_index, start_frame, end_frame, worker_c node_output = node.compute(trace_detection, start_frame, end_frame, segment_index, max_margin) # set sample index to local node_output[0]["sample_index"] += extra_margin - elif isinstance(node, PeakRetriever): + elif isinstance(node, PeakSource): node_output = node.compute(traces_chunk, start_frame, end_frame, segment_index, max_margin) else: # TODO later when in master: change the signature of all nodes (or maybe not!) 
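A minimal sketch of the intended usage of the new SpikeRetriever as the first node of a pipeline (the imports mirror the test file that follows; the recording/sorting pair, the ms_before/ms_after values and the job_kwargs are illustrative assumptions, not part of the patch):

    import spikeinterface.full as si
    from spikeinterface.core.node_pipeline import SpikeRetriever, ExtractDenseWaveforms, run_node_pipeline

    # recording and sorting are an existing paired recording/sorting (e.g. from read_mearec)
    we = si.extract_waveforms(recording, sorting, mode="memory")
    extremum_channel_inds = si.get_template_extremum_channel(we, peak_sign="neg", outputs="index")

    # first node: inject the sorting into the pipeline as a peak source
    spike_retriever = SpikeRetriever(
        recording, sorting, channel_from_template=True, extremum_channel_inds=extremum_channel_inds
    )
    # second node: extract dense waveforms around each retrieved spike
    dense_waveforms = ExtractDenseWaveforms(
        recording, parents=[spike_retriever], ms_before=1.0, ms_after=2.0, return_output=True
    )
    job_kwargs = dict(chunk_duration="1s", n_jobs=1, progress_bar=False)
    waveforms = run_node_pipeline(recording, [spike_retriever, dense_waveforms], job_kwargs, squeeze_output=True)

With channel_from_template=False, each spike instead recomputes its extremum channel within radius_um around the template max channel, as exercised in the tests below.
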
diff --git a/src/spikeinterface/core/tests/test_node_pipeline.py b/src/spikeinterface/core/tests/test_node_pipeline.py index e9dfb43a66..35388a33a5 100644 --- a/src/spikeinterface/core/tests/test_node_pipeline.py +++ b/src/spikeinterface/core/tests/test_node_pipeline.py @@ -14,9 +14,10 @@ from spikeinterface.core.node_pipeline import ( run_node_pipeline, PeakRetriever, + SpikeRetriever, PipelineNode, ExtractDenseWaveforms, - base_peak_dtype, + sorting_to_peak, ) @@ -77,7 +78,8 @@ def test_run_node_pipeline(): # recording = MEArecRecordingExtractor(local_path) recording, sorting = read_mearec(local_path) - job_kwargs = dict(chunk_duration="0.5s", n_jobs=2, progress_bar=False) + # job_kwargs = dict(chunk_duration="0.5s", n_jobs=2, progress_bar=False) + job_kwargs = dict(chunk_duration="0.5s", n_jobs=1, progress_bar=False) spikes = sorting.to_spike_vector() @@ -88,98 +90,109 @@ def test_run_node_pipeline(): # create peaks from spikes we = extract_waveforms(recording, sorting, mode="memory", **job_kwargs) extremum_channel_inds = get_template_extremum_channel(we, peak_sign="neg", outputs="index") - print(extremum_channel_inds) - ext_channel_inds = np.array([extremum_channel_inds[unit_id] for unit_id in sorting.unit_ids]) - print(ext_channel_inds) - peaks = np.zeros(spikes.size, dtype=base_peak_dtype) - peaks["sample_index"] = spikes["sample_index"] - peaks["channel_index"] = ext_channel_inds[spikes["unit_index"]] - peaks["amplitude"] = 0.0 - peaks["segment_index"] = 0 - - # one step only : squeeze output - peak_retriever = PeakRetriever(recording, peaks) - nodes = [ - peak_retriever, - AmplitudeExtractionNode(recording, parents=[peak_retriever], param0=6.6), - ] - step_one = run_node_pipeline(recording, nodes, job_kwargs, squeeze_output=True) - assert np.allclose(np.abs(peaks["amplitude"]), step_one["abs_amplitude"]) - - # 3 nodes two have outputs - ms_before = 0.5 - ms_after = 1.0 + peaks = sorting_to_peak(sorting, extremum_channel_inds) + + peak_retriever = PeakRetriever(recording, peaks) - dense_waveforms = ExtractDenseWaveforms( - recording, parents=[peak_retriever], ms_before=ms_before, ms_after=ms_after, return_output=False - ) - waveform_denoiser = WaveformDenoiser(recording, parents=[peak_retriever, dense_waveforms], return_output=False) - amplitue_extraction = AmplitudeExtractionNode(recording, parents=[peak_retriever], param0=6.6, return_output=True) - waveforms_rms = WaveformsRootMeanSquare(recording, parents=[peak_retriever, dense_waveforms], return_output=True) - denoised_waveforms_rms = WaveformsRootMeanSquare( - recording, parents=[peak_retriever, waveform_denoiser], return_output=True - ) - - nodes = [ - peak_retriever, - dense_waveforms, - waveform_denoiser, - amplitue_extraction, - waveforms_rms, - denoised_waveforms_rms, - ] - - # gather memory mode - output = run_node_pipeline(recording, nodes, job_kwargs, gather_mode="memory") - amplitudes, waveforms_rms, denoised_waveforms_rms = output - assert np.allclose(np.abs(peaks["amplitude"]), amplitudes["abs_amplitude"]) - - num_peaks = peaks.shape[0] - num_channels = recording.get_num_channels() - assert waveforms_rms.shape[0] == num_peaks - assert waveforms_rms.shape[1] == num_channels - - assert waveforms_rms.shape[0] == num_peaks - assert waveforms_rms.shape[1] == num_channels - - # gather npy mode - folder = cache_folder / "pipeline_folder" - if folder.is_dir(): - shutil.rmtree(folder) - output = run_node_pipeline( - recording, - nodes, - job_kwargs, - gather_mode="npy", - folder=folder, - names=["amplitudes", "waveforms_rms", 
"denoised_waveforms_rms"], - ) - amplitudes2, waveforms_rms2, denoised_waveforms_rms2 = output - - amplitudes_file = folder / "amplitudes.npy" - assert amplitudes_file.is_file() - amplitudes3 = np.load(amplitudes_file) - assert np.array_equal(amplitudes, amplitudes2) - assert np.array_equal(amplitudes2, amplitudes3) - - waveforms_rms_file = folder / "waveforms_rms.npy" - assert waveforms_rms_file.is_file() - waveforms_rms3 = np.load(waveforms_rms_file) - assert np.array_equal(waveforms_rms, waveforms_rms2) - assert np.array_equal(waveforms_rms2, waveforms_rms3) - - denoised_waveforms_rms_file = folder / "denoised_waveforms_rms.npy" - assert denoised_waveforms_rms_file.is_file() - denoised_waveforms_rms3 = np.load(denoised_waveforms_rms_file) - assert np.array_equal(denoised_waveforms_rms, denoised_waveforms_rms2) - assert np.array_equal(denoised_waveforms_rms2, denoised_waveforms_rms3) - - # Test pickle mechanism - for node in nodes: - import pickle - - pickled_node = pickle.dumps(node) - unpickled_node = pickle.loads(pickled_node) + # channel index is from template + spike_retriever_T = SpikeRetriever(recording, sorting, + channel_from_template=True, + extremum_channel_inds=extremum_channel_inds) + # channel index is per spike + spike_retriever_S = SpikeRetriever(recording, sorting, + channel_from_template=False, + extremum_channel_inds=extremum_channel_inds, + radius_um=50, + peak_sign="neg") + + # test with 2 diffrents first node + for peak_source in (peak_retriever, spike_retriever_T, spike_retriever_S): + + + + + # one step only : squeeze output + nodes = [ + peak_source, + AmplitudeExtractionNode(recording, parents=[peak_source], param0=6.6), + ] + step_one = run_node_pipeline(recording, nodes, job_kwargs, squeeze_output=True) + assert np.allclose(np.abs(peaks["amplitude"]), step_one["abs_amplitude"]) + + # 3 nodes two have outputs + ms_before = 0.5 + ms_after = 1.0 + peak_retriever = PeakRetriever(recording, peaks) + dense_waveforms = ExtractDenseWaveforms( + recording, parents=[peak_source], ms_before=ms_before, ms_after=ms_after, return_output=False + ) + waveform_denoiser = WaveformDenoiser(recording, parents=[peak_source, dense_waveforms], return_output=False) + amplitue_extraction = AmplitudeExtractionNode(recording, parents=[peak_source], param0=6.6, return_output=True) + waveforms_rms = WaveformsRootMeanSquare(recording, parents=[peak_source, dense_waveforms], return_output=True) + denoised_waveforms_rms = WaveformsRootMeanSquare( + recording, parents=[peak_source, waveform_denoiser], return_output=True + ) + + nodes = [ + peak_source, + dense_waveforms, + waveform_denoiser, + amplitue_extraction, + waveforms_rms, + denoised_waveforms_rms, + ] + + # gather memory mode + output = run_node_pipeline(recording, nodes, job_kwargs, gather_mode="memory") + amplitudes, waveforms_rms, denoised_waveforms_rms = output + assert np.allclose(np.abs(peaks["amplitude"]), amplitudes["abs_amplitude"]) + + num_peaks = peaks.shape[0] + num_channels = recording.get_num_channels() + assert waveforms_rms.shape[0] == num_peaks + assert waveforms_rms.shape[1] == num_channels + + assert waveforms_rms.shape[0] == num_peaks + assert waveforms_rms.shape[1] == num_channels + + # gather npy mode + folder = cache_folder / "pipeline_folder" + if folder.is_dir(): + shutil.rmtree(folder) + output = run_node_pipeline( + recording, + nodes, + job_kwargs, + gather_mode="npy", + folder=folder, + names=["amplitudes", "waveforms_rms", "denoised_waveforms_rms"], + ) + amplitudes2, waveforms_rms2, 
denoised_waveforms_rms2 = output + + amplitudes_file = folder / "amplitudes.npy" + assert amplitudes_file.is_file() + amplitudes3 = np.load(amplitudes_file) + assert np.array_equal(amplitudes, amplitudes2) + assert np.array_equal(amplitudes2, amplitudes3) + + waveforms_rms_file = folder / "waveforms_rms.npy" + assert waveforms_rms_file.is_file() + waveforms_rms3 = np.load(waveforms_rms_file) + assert np.array_equal(waveforms_rms, waveforms_rms2) + assert np.array_equal(waveforms_rms2, waveforms_rms3) + + denoised_waveforms_rms_file = folder / "denoised_waveforms_rms.npy" + assert denoised_waveforms_rms_file.is_file() + denoised_waveforms_rms3 = np.load(denoised_waveforms_rms_file) + assert np.array_equal(denoised_waveforms_rms, denoised_waveforms_rms2) + assert np.array_equal(denoised_waveforms_rms2, denoised_waveforms_rms3) + + # Test pickle mechanism + for node in nodes: + import pickle + + pickled_node = pickle.dumps(node) + unpickled_node = pickle.loads(pickled_node) if __name__ == "__main__": From 2d7b08f2744c550dd630add451e85c28f4f7336d Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 29 Aug 2023 09:55:09 +0200 Subject: [PATCH 14/73] Add zugbruecke in extractors install for plexon2 --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 3ecfbe2718..ddb0de4893 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,6 +68,7 @@ extractors = [ "ONE-api>=1.19.1", "ibllib>=2.21.0", "pymatreader>=0.0.32", # For cell explorer matlab files + "zugbruecke", # For plexon2 ] streaming_extractors = [ From 4a9f429f2ad3b18057db6c6432960c401f0ff14c Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 29 Aug 2023 09:57:16 +0200 Subject: [PATCH 15/73] Update naming following #1626 --- src/spikeinterface/extractors/neoextractors/plexon2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/extractors/neoextractors/plexon2.py b/src/spikeinterface/extractors/neoextractors/plexon2.py index c3869dbadc..148deb48e9 100644 --- a/src/spikeinterface/extractors/neoextractors/plexon2.py +++ b/src/spikeinterface/extractors/neoextractors/plexon2.py @@ -52,7 +52,7 @@ class Plexon2SortingExtractor(NeoBaseSortingExtractor): mode = "file" NeoRawIOClass = "Plexon2RawIO" - handle_spike_frame_directly = True + neo_returns_frames = True name = "plexon2" def __init__(self, file_path): From aae39c6d1a2f4c3f952e19a81e031eb7abb909ae Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 29 Aug 2023 10:34:17 +0200 Subject: [PATCH 16/73] Install wine for plexon2 --- .github/actions/build-test-environment/action.yml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/.github/actions/build-test-environment/action.yml b/.github/actions/build-test-environment/action.yml index 004fe31203..7b5debdd51 100644 --- a/.github/actions/build-test-environment/action.yml +++ b/.github/actions/build-test-environment/action.yml @@ -40,3 +40,9 @@ runs: tar xvzf git-annex-standalone-amd64.tar.gz echo "$(pwd)/git-annex.linux" >> $GITHUB_PATH shell: bash + - name: Install wine (needed for Plexon2) + run: | + sudo rm -f /etc/apt/sources.list.d/microsoft-prod.list + sudo dpkg --add-architecture i386 + sudo apt-get update -qq + sudo apt-get install -yqq --allow-downgrades libc6:i386 libgcc-s1:i386 libstdc++6:i386 wine From 26fdba2b0d85f5c174f60462d6edb6128876c14a Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 29 Aug 2023 10:36:23 +0200 Subject: [PATCH 17/73] Add shell --- .github/actions/build-test-environment/action.yml | 1 + 1 file 
changed, 1 insertion(+) diff --git a/.github/actions/build-test-environment/action.yml b/.github/actions/build-test-environment/action.yml index 7b5debdd51..b056bd3353 100644 --- a/.github/actions/build-test-environment/action.yml +++ b/.github/actions/build-test-environment/action.yml @@ -46,3 +46,4 @@ runs: sudo dpkg --add-architecture i386 sudo apt-get update -qq sudo apt-get install -yqq --allow-downgrades libc6:i386 libgcc-s1:i386 libstdc++6:i386 wine + shell: bash From 8bcec5f9df9d2205b0ecd222aac5df135492d730 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 29 Aug 2023 11:39:12 +0200 Subject: [PATCH 18/73] Expose sampling_frequency in pl2 sorting (needed for multi-stream) --- src/spikeinterface/extractors/neoextractors/plexon2.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/extractors/neoextractors/plexon2.py b/src/spikeinterface/extractors/neoextractors/plexon2.py index 148deb48e9..966fc253ad 100644 --- a/src/spikeinterface/extractors/neoextractors/plexon2.py +++ b/src/spikeinterface/extractors/neoextractors/plexon2.py @@ -48,6 +48,8 @@ class Plexon2SortingExtractor(NeoBaseSortingExtractor): ---------- file_path: str The file path to load the recordings from. + sampling_frequency: float, default: None + The sampling frequency of the sorting (required for multiple streams with different sampling frequencies). """ mode = "file" @@ -55,13 +57,13 @@ class Plexon2SortingExtractor(NeoBaseSortingExtractor): neo_returns_frames = True name = "plexon2" - def __init__(self, file_path): + def __init__(self, file_path, sampling_frequency=None): from neo.rawio import Plexon2RawIO neo_kwargs = self.map_to_neo_kwargs(file_path) neo_reader = Plexon2RawIO(**neo_kwargs) neo_reader.parse_header() - NeoBaseSortingExtractor.__init__(self, **neo_kwargs) + NeoBaseSortingExtractor.__init__(self, sampling_frequency=sampling_frequency, **neo_kwargs) self._kwargs.update({"file_path": str(file_path)}) @classmethod From 3ffb76c444e2556fd62efbfab677d6dcd1cd7706 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 29 Aug 2023 11:59:55 +0200 Subject: [PATCH 19/73] Add sampling_frequency kwargs in tests --- src/spikeinterface/extractors/tests/test_neoextractors.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/extractors/tests/test_neoextractors.py b/src/spikeinterface/extractors/tests/test_neoextractors.py index 0405d7b129..e8f565bede 100644 --- a/src/spikeinterface/extractors/tests/test_neoextractors.py +++ b/src/spikeinterface/extractors/tests/test_neoextractors.py @@ -156,9 +156,7 @@ class PlexonSortingTest(SortingCommonTestSuite, unittest.TestCase): class Plexon2SortingTest(SortingCommonTestSuite, unittest.TestCase): ExtractorClass = Plexon2SortingExtractor downloads = ["plexon"] - entities = [ - ("plexon/4chDemoPL2.pl2"), - ] + entities = [("plexon/4chDemoPL2.pl2", {"sampling_frequency": 40000})] class NeuralynxRecordingTest(RecordingCommonTestSuite, unittest.TestCase): @@ -328,7 +326,7 @@ def test_pickling(self): # test = PlexonRecordingTest() # test = PlexonSortingTest() # test = NeuralynxRecordingTest() - test = BlackrockRecordingTest() + test = Plexon2RecordingTest() # test = MCSRawRecordingTest() # test = KiloSortSortingTest() # test = Spike2RecordingTest() From 6247dc090b5604dca5dd73fb151f55f781bca8d3 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 29 Aug 2023 12:55:33 +0200 Subject: [PATCH 20/73] Update self._kwargs --- src/spikeinterface/extractors/neoextractors/plexon2.py | 2 +- 1 file 
changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/extractors/neoextractors/plexon2.py b/src/spikeinterface/extractors/neoextractors/plexon2.py index 966fc253ad..8dbfc67e90 100644 --- a/src/spikeinterface/extractors/neoextractors/plexon2.py +++ b/src/spikeinterface/extractors/neoextractors/plexon2.py @@ -64,7 +64,7 @@ def __init__(self, file_path, sampling_frequency=None): neo_reader = Plexon2RawIO(**neo_kwargs) neo_reader.parse_header() NeoBaseSortingExtractor.__init__(self, sampling_frequency=sampling_frequency, **neo_kwargs) - self._kwargs.update({"file_path": str(file_path)}) + self._kwargs.update({"file_path": str(file_path), "sampling_frequency": sampling_frequency}) @classmethod def map_to_neo_kwargs(cls, file_path): From dfcd3caf8a18168ae05b564d62e7ce15c3ac185d Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 29 Aug 2023 16:46:55 +0200 Subject: [PATCH 21/73] Mark plexon2 tests and install Wine only if needed --- .../actions/build-test-environment/action.yml | 7 --- .github/actions/install-wine/action.yml | 21 ++++++++ .github/workflows/full-test-with-codecov.yml | 2 +- .github/workflows/full-test.yml | 12 ++++- .../extractors/tests/test_neoextractors.py | 48 ++++++++++--------- 5 files changed, 59 insertions(+), 31 deletions(-) create mode 100644 .github/actions/install-wine/action.yml diff --git a/.github/actions/build-test-environment/action.yml b/.github/actions/build-test-environment/action.yml index b056bd3353..004fe31203 100644 --- a/.github/actions/build-test-environment/action.yml +++ b/.github/actions/build-test-environment/action.yml @@ -40,10 +40,3 @@ runs: tar xvzf git-annex-standalone-amd64.tar.gz echo "$(pwd)/git-annex.linux" >> $GITHUB_PATH shell: bash - - name: Install wine (needed for Plexon2) - run: | - sudo rm -f /etc/apt/sources.list.d/microsoft-prod.list - sudo dpkg --add-architecture i386 - sudo apt-get update -qq - sudo apt-get install -yqq --allow-downgrades libc6:i386 libgcc-s1:i386 libstdc++6:i386 wine - shell: bash diff --git a/.github/actions/install-wine/action.yml b/.github/actions/install-wine/action.yml new file mode 100644 index 0000000000..3ae08ecd34 --- /dev/null +++ b/.github/actions/install-wine/action.yml @@ -0,0 +1,21 @@ +name: Install packages +description: This action installs the package and its dependencies for testing + +inputs: + python-version: + description: 'Python version to set up' + required: false + os: + description: 'Operating system to set up' + required: false + +runs: + using: "composite" + steps: + - name: Install wine (needed for Plexon2) + run: | + sudo rm -f /etc/apt/sources.list.d/microsoft-prod.list + sudo dpkg --add-architecture i386 + sudo apt-get update -qq + sudo apt-get install -yqq --allow-downgrades libc6:i386 libgcc-s1:i386 libstdc++6:i386 wine + shell: bash diff --git a/.github/workflows/full-test-with-codecov.yml b/.github/workflows/full-test-with-codecov.yml index a5561c2ffc..d0bf109a00 100644 --- a/.github/workflows/full-test-with-codecov.yml +++ b/.github/workflows/full-test-with-codecov.yml @@ -54,7 +54,7 @@ jobs: - name: run tests run: | source ${{ github.workspace }}/test_env/bin/activate - pytest -m "not sorters_external" --cov=./ --cov-report xml:./coverage.xml -vv -ra --durations=0 | tee report_full.txt; test ${PIPESTATUS[0]} -eq 0 || exit 1 + pytest -m "not sorters_external and not plexon2" --cov=./ --cov-report xml:./coverage.xml -vv -ra --durations=0 | tee report_full.txt; test ${PIPESTATUS[0]} -eq 0 || exit 1 echo "# Timing profile of full tests" >> 
$GITHUB_STEP_SUMMARY python ./.github/build_job_summary.py report_full.txt >> $GITHUB_STEP_SUMMARY cat $GITHUB_STEP_SUMMARY diff --git a/.github/workflows/full-test.yml b/.github/workflows/full-test.yml index ac5130bade..a343500c08 100644 --- a/.github/workflows/full-test.yml +++ b/.github/workflows/full-test.yml @@ -75,6 +75,10 @@ jobs: echo "Extractors changed" echo "EXTRACTORS_CHANGED=true" >> $GITHUB_OUTPUT fi + if [[ $file == *"plexon2"* ]]; then + echo "Plexon2 changed" + echo "PLEXON2_CHANGED=true" >> $GITHUB_OUTPUT + fi if [[ $file == *"/preprocessing/"* ]]; then echo "Preprocessing changed" echo "PREPROCESSING_CHANGED=true" >> $GITHUB_OUTPUT @@ -122,11 +126,14 @@ jobs: done - name: Set execute permissions on run_tests.sh run: chmod +x .github/run_tests.sh + - name: Install Wine (Plexon2) + if: ${{ steps.modules-changed.outputs.PLEXON2_CHANGED == 'true' }} + uses: ./.github/actions/install-wine - name: Test core run: ./.github/run_tests.sh core - name: Test extractors if: ${{ steps.modules-changed.outputs.EXTRACTORS_CHANGED == 'true' || steps.modules-changed.outputs.CORE_CHANGED == 'true' }} - run: ./.github/run_tests.sh "extractors and not streaming_extractors" + run: ./.github/run_tests.sh "extractors and not streaming_extractors and not plexon2" - name: Test preprocessing if: ${{ steps.modules-changed.outputs.PREPROCESSING_CHANGED == 'true' || steps.modules-changed.outputs.CORE_CHANGED == 'true' }} run: ./.github/run_tests.sh preprocessing @@ -157,3 +164,6 @@ jobs: - name: Test internal sorters if: ${{ steps.modules-changed.outputs.SORTERS_INTERNAL_CHANGED == 'true' || steps.modules-changed.outputs.SORTINGCOMPONENTS_CHANGED || steps.modules-changed.outputs.CORE_CHANGED == 'true' }} run: ./.github/run_tests.sh sorters_internal + - name: Test plexon2 + if: ${{ steps.modules-changed.outputs.PLEXON2_CHANGED == 'true' }} + run: ./.github/run_tests.sh plexon2 diff --git a/src/spikeinterface/extractors/tests/test_neoextractors.py b/src/spikeinterface/extractors/tests/test_neoextractors.py index e8f565bede..da162eccf1 100644 --- a/src/spikeinterface/extractors/tests/test_neoextractors.py +++ b/src/spikeinterface/extractors/tests/test_neoextractors.py @@ -121,14 +121,6 @@ class NeuroScopeSortingTest(SortingCommonTestSuite, unittest.TestCase): ] -class Plexon2EventTest(EventCommonTestSuite, unittest.TestCase): - ExtractorClass = Plexon2EventExtractor - downloads = ["plexon"] - entities = [ - ("plexon/4chDemoPL2.pl2"), - ] - - class PlexonRecordingTest(RecordingCommonTestSuite, unittest.TestCase): ExtractorClass = PlexonRecordingExtractor downloads = ["plexon"] @@ -137,14 +129,6 @@ class PlexonRecordingTest(RecordingCommonTestSuite, unittest.TestCase): ] -class Plexon2RecordingTest(RecordingCommonTestSuite, unittest.TestCase): - ExtractorClass = Plexon2RecordingExtractor - downloads = ["plexon"] - entities = [ - ("plexon/4chDemoPL2.pl2", {"stream_id": "3"}), - ] - - class PlexonSortingTest(SortingCommonTestSuite, unittest.TestCase): ExtractorClass = PlexonSortingExtractor downloads = ["plexon"] @@ -153,12 +137,6 @@ class PlexonSortingTest(SortingCommonTestSuite, unittest.TestCase): ] -class Plexon2SortingTest(SortingCommonTestSuite, unittest.TestCase): - ExtractorClass = Plexon2SortingExtractor - downloads = ["plexon"] - entities = [("plexon/4chDemoPL2.pl2", {"sampling_frequency": 40000})] - - class NeuralynxRecordingTest(RecordingCommonTestSuite, unittest.TestCase): ExtractorClass = NeuralynxRecordingExtractor downloads = ["neuralynx"] @@ -312,6 +290,32 @@ def test_pickling(self): pass +# 
We mark plexon2 tests as they require additional dependencies (wine) +@pytest.mark.plexon2 +class Plexon2RecordingTest(RecordingCommonTestSuite, unittest.TestCase): + ExtractorClass = Plexon2RecordingExtractor + downloads = ["plexon"] + entities = [ + ("plexon/4chDemoPL2.pl2", {"stream_id": "3"}), + ] + + +@pytest.mark.plexon2 +class Plexon2EventTest(EventCommonTestSuite, unittest.TestCase): + ExtractorClass = Plexon2EventExtractor + downloads = ["plexon"] + entities = [ + ("plexon/4chDemoPL2.pl2"), + ] + + +@pytest.mark.plexon2 +class Plexon2SortingTest(SortingCommonTestSuite, unittest.TestCase): + ExtractorClass = Plexon2SortingExtractor + downloads = ["plexon"] + entities = [("plexon/4chDemoPL2.pl2", {"sampling_frequency": 40000})] + + if __name__ == "__main__": # test = MearecSortingTest() # test = SpikeGLXRecordingTest() From 23f8677f148a819a57c1f859aa327ca40d124f25 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 29 Aug 2023 17:33:15 +0200 Subject: [PATCH 22/73] Install zugbruecke not on win --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index ddb0de4893..b0bf4cdcf5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,7 +68,7 @@ extractors = [ "ONE-api>=1.19.1", "ibllib>=2.21.0", "pymatreader>=0.0.32", # For cell explorer matlab files - "zugbruecke", # For plexon2 + "zugbruecke>=0.2; sys_platform!='win32'", # For plexon2 ] streaming_extractors = [ From 6f902d46dad68d5f69322175537e7a89b4e0bd43 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Wed, 30 Aug 2023 11:11:39 +0200 Subject: [PATCH 23/73] patch --- src/spikeinterface/preprocessing/silence_periods.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py index 223122e927..c2ffcc6843 100644 --- a/src/spikeinterface/preprocessing/silence_periods.py +++ b/src/spikeinterface/preprocessing/silence_periods.py @@ -46,7 +46,7 @@ def __init__(self, recording, list_periods, mode="zeros", **random_chunk_kwargs) num_seg = recording.get_num_segments() if num_seg == 1: - if isinstance(list_periods, (list, np.ndarray)) and not np.isscalar(list_periods[0]): + if isinstance(list_periods, (list, np.ndarray)) and np.array(list_periods).ndim == 2: # when unique segment accept list instead of of list of list/arrays list_periods = [list_periods] From a797aa33c561871b10b4a441985d89546e8ebc2e Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 30 Aug 2023 15:55:16 +0200 Subject: [PATCH 24/73] Improve debug plots and handle_collisions=True by default --- .../postprocessing/amplitude_scalings.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 1f7923eb05..a9b3898388 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -187,7 +187,7 @@ def compute_amplitude_scalings( max_dense_channels=16, ms_before=None, ms_after=None, - handle_collisions=False, + handle_collisions=True, delta_collision_ms=2, load_if_exists=False, outputs="concatenated", @@ -212,7 +212,7 @@ def compute_amplitude_scalings( ms_after : float, default: None The cut out to apply after the spike peak to extract local waveforms. If None, the WaveformExtractor ms_after is used. 
- handle_collisions: bool, default: False + handle_collisions: bool, default: True Whether to handle collisions between spikes. If True, the amplitude scaling of colliding spikes (defined as spikes within `delta_collision_ms` ms and with overlapping sparsity) is computed by fitting a multi-linear regression model (with `sklearn.LinearRegression`). If False, each spike is fitted independently. @@ -598,12 +598,12 @@ def fit_collision( # for i in range(num_collisions): # overlapping_spikes = collisions[collision_keys[i]] -# ax = _plot_one_collision( +# ax = plot_one_collision( # we, collision_keys[i], overlapping_spikes, spikes, scalings=scalings, sparsity=sparsity # ) -# def _plot_one_collision( +# def plot_one_collision( # we, # spike_index, # overlapping_spikes, @@ -611,9 +611,13 @@ def fit_collision( # scalings=None, # sparsity=None, # cut_out_samples=100, +# ax=None # ): # import matplotlib.pyplot as plt +# if ax is None: +# fig, ax = plt.subplots() + # recording = we.recording # nbefore_nafter_max = max(we.nafter, we.nbefore) # cut_out_samples = max(cut_out_samples, nbefore_nafter_max) @@ -644,7 +648,7 @@ def fit_collision( # tr_overlap = recording.get_traces(start_frame=sf, end_frame=ef, channel_ids=channel_ids, return_scaled=True) # ts = np.arange(sf, ef) / recording.sampling_frequency * 1000 # max_tr = np.max(np.abs(tr_overlap)) -# fig, ax = plt.subplots() + # for ch, tr in enumerate(tr_overlap.T): # _ = ax.plot(ts, tr + 1.2 * ch * max_tr, color="k") # ax.text(ts[0], 1.2 * ch * max_tr - 0.3 * max_tr, f"Ch:{channel_ids[ch]}") From b2e737ec734d8a488bf8997ca134ad7190b82d2c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 1 Sep 2023 13:46:48 +0000 Subject: [PATCH 25/73] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/node_pipeline.py | 19 ++++++-------- .../core/tests/test_node_pipeline.py | 25 +++++++++---------- 2 files changed, 20 insertions(+), 24 deletions(-) diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index ff747fe2a0..610ae42398 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -141,7 +141,7 @@ class SpikeRetriever(PeakSource): """ This class is usefull to inject a sorting object in the node pipepline mechanisim. It allows to compute some post processing with the same machinery used for sorting components. - This is a first step to totaly refactor: + This is a first step to totaly refactor: * compute_spike_locations() * compute_amplitude_scalings() * compute_spike_amplitudes() @@ -164,16 +164,14 @@ class SpikeRetriever(PeakSource): Peak sign to find the max channel. 
Used only when channel_from_template=False """ - def __init__(self, recording, sorting, - channel_from_template=True, - extremum_channel_inds=None, - radius_um=50, - peak_sign="neg" - ): + + def __init__( + self, recording, sorting, channel_from_template=True, extremum_channel_inds=None, radius_um=50, peak_sign="neg" + ): PipelineNode.__init__(self, recording, return_output=False) self.channel_from_template = channel_from_template - + assert extremum_channel_inds is not None, "SpikeRetriever need the dict extremum_channel_inds" self.peaks = sorting_to_peak(sorting, extremum_channel_inds) @@ -181,8 +179,7 @@ def __init__(self, recording, sorting, if not channel_from_template: channel_distance = get_channel_distances(recording) self.neighbours_mask = channel_distance < radius_um - self.peak_sign = peak_sign - + self.peak_sign = peak_sign # precompute segment slice self.segment_slices = [] @@ -219,7 +216,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): elif self.peak_sign == "pos": local_peaks[i]["channel_index"] = chans[np.argmax(sparse_wfs)] elif self.peak_sign == "both": - local_peaks[i]["channel_index"] = chans[np.argmax(np.abs(sparse_wfs))] + local_peaks[i]["channel_index"] = chans[np.argmax(np.abs(sparse_wfs))] # TODO: "amplitude" ??? diff --git a/src/spikeinterface/core/tests/test_node_pipeline.py b/src/spikeinterface/core/tests/test_node_pipeline.py index 8bea0bafb1..d0d49b865c 100644 --- a/src/spikeinterface/core/tests/test_node_pipeline.py +++ b/src/spikeinterface/core/tests/test_node_pipeline.py @@ -81,25 +81,24 @@ def test_run_node_pipeline(): we = extract_waveforms(recording, sorting, mode="memory", **job_kwargs) extremum_channel_inds = get_template_extremum_channel(we, peak_sign="neg", outputs="index") peaks = sorting_to_peak(sorting, extremum_channel_inds) - + peak_retriever = PeakRetriever(recording, peaks) # channel index is from template - spike_retriever_T = SpikeRetriever(recording, sorting, - channel_from_template=True, - extremum_channel_inds=extremum_channel_inds) + spike_retriever_T = SpikeRetriever( + recording, sorting, channel_from_template=True, extremum_channel_inds=extremum_channel_inds + ) # channel index is per spike - spike_retriever_S = SpikeRetriever(recording, sorting, - channel_from_template=False, - extremum_channel_inds=extremum_channel_inds, - radius_um=50, - peak_sign="neg") + spike_retriever_S = SpikeRetriever( + recording, + sorting, + channel_from_template=False, + extremum_channel_inds=extremum_channel_inds, + radius_um=50, + peak_sign="neg", + ) # test with 2 diffrents first node for loop, peak_source in enumerate((peak_retriever, spike_retriever_T, spike_retriever_S)): - - - - # one step only : squeeze output nodes = [ peak_source, From 287e8af9621385d4fa835be6356b7695993cdc16 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 4 Sep 2023 16:46:18 +0200 Subject: [PATCH 26/73] Fix tests --- src/spikeinterface/postprocessing/amplitude_scalings.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 0dd2587fba..5a0148c5c4 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -24,6 +24,7 @@ def __init__(self, waveform_extractor): self.spikes = self.waveform_extractor.sorting.to_spike_vector( extremum_channel_inds=extremum_channel_inds, use_cache=False ) + self.collisions = None def _set_params( 
self, @@ -138,9 +139,11 @@ def _run(self, **job_kwargs): if handle_collisions: for collision in collisions: collisions_dict.update(collision) + self.collisions = collisions_dict + # Note: collisions are note in _extension_data because they are not pickable. We only store the indices + self._extension_data["collisions"] = np.array(list(collisions_dict.keys())) - self._extension_data[f"amplitude_scalings"] = amp_scalings - self._extension_data[f"collisions"] = collisions_dict + self._extension_data["amplitude_scalings"] = amp_scalings def get_data(self, outputs="concatenated"): """ From 280adf6e7260c69acd93f8fc4414cbbd569860a4 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Tue, 5 Sep 2023 14:59:13 -0400 Subject: [PATCH 27/73] update si_env for mac and windows, add error log --- installation_tips/check_your_install.py | 4 ++-- .../full_spikeinterface_environment_mac.yml | 13 +++++-------- .../full_spikeinterface_environment_windows.yml | 12 +++++------- 3 files changed, 12 insertions(+), 17 deletions(-) diff --git a/installation_tips/check_your_install.py b/installation_tips/check_your_install.py index 20809ec6c0..c4aa860979 100644 --- a/installation_tips/check_your_install.py +++ b/installation_tips/check_your_install.py @@ -103,8 +103,8 @@ def _clean(): try: func() done = '...OK' - except: - done = '...Fail' + except Exception as err: + done = f'...Fail, Error: {err}' print(label, done) if platform.system() == "Windows": diff --git a/installation_tips/full_spikeinterface_environment_mac.yml b/installation_tips/full_spikeinterface_environment_mac.yml index 8b872981aa..867de7f98b 100755 --- a/installation_tips/full_spikeinterface_environment_mac.yml +++ b/installation_tips/full_spikeinterface_environment_mac.yml @@ -3,12 +3,10 @@ channels: - conda-forge - defaults dependencies: - - python=3.9 + - python=3.10 - pip>=21.0 - # numpy 1.21 break numba which break tridesclous - - numpy<1.22 - # joblib 1.2 is breaking hdbscan - - joblib=1.1 + - numpy + - joblib - tqdm - matplotlib - h5py @@ -30,13 +28,12 @@ dependencies: - pip: # - PyQt5 - ephyviewer - - neo>=0.11 - - elephant>=0.10.0 + - neo>=0.12 - probeinterface>=0.2.11 - MEArec>=1.8 - spikeinterface[full, widgets] - spikeinterface-gui - - tridesclous>=1.6.6.1 + - tridesclous>=1.6.8 # - phy==2.0b5 - mountainsort4>=1.0.0 - mountainsort5>=0.3.0 diff --git a/installation_tips/full_spikeinterface_environment_windows.yml b/installation_tips/full_spikeinterface_environment_windows.yml index 8c793edcb1..38c26e6a78 100755 --- a/installation_tips/full_spikeinterface_environment_windows.yml +++ b/installation_tips/full_spikeinterface_environment_windows.yml @@ -3,12 +3,11 @@ channels: - conda-forge - defaults dependencies: - - python=3.9 + - python=3.10 - pip>=21.0 # numpy 1.21 break numba which break tridesclous - - numpy<1.22 - # joblib 1.2 is breaking hdbscan - - joblib=1.1 + - numpy + - joblib - tqdm - matplotlib - h5py @@ -26,11 +25,10 @@ dependencies: - ipympl - pip: - ephyviewer - - neo>=0.11 - - elephant>=0.10.0 + - neo>=0.12 - probeinterface>=0.2.11 - MEArec>=1.8 - spikeinterface[full, widgets] - spikeinterface-gui - - tridesclous>=1.6.6.1 + - tridesclous>=1.6.8 # - phy==2.0b5 From 963914e4f25130e098aae650d8e38084191ed499 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 5 Sep 2023 19:07:18 +0000 Subject: [PATCH 28/73] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- 
installation_tips/check_your_install.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/installation_tips/check_your_install.py b/installation_tips/check_your_install.py index c4aa860979..2b13a941cd 100644 --- a/installation_tips/check_your_install.py +++ b/installation_tips/check_your_install.py @@ -104,7 +104,7 @@ def _clean(): func() done = '...OK' except Exception as err: - done = f'...Fail, Error: {err}' + done = f'...Fail, Error: {err}' print(label, done) if platform.system() == "Windows": From 81a610d9bb036392c53e23b6546365c65b2ad68f Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Wed, 6 Sep 2023 09:41:23 +0200 Subject: [PATCH 29/73] add neuroexplorer --- .../extractors/neoextractors/__init__.py | 2 + .../extractors/neoextractors/neuroexplorer.py | 66 +++++++++++++++++++ .../extractors/tests/test_neoextractors.py | 11 ++++ 3 files changed, 79 insertions(+) create mode 100644 src/spikeinterface/extractors/neoextractors/neuroexplorer.py diff --git a/src/spikeinterface/extractors/neoextractors/__init__.py b/src/spikeinterface/extractors/neoextractors/__init__.py index 0d9da1960a..a6c8f27ac3 100644 --- a/src/spikeinterface/extractors/neoextractors/__init__.py +++ b/src/spikeinterface/extractors/neoextractors/__init__.py @@ -16,6 +16,7 @@ read_neuroscope_sorting, read_neuroscope, ) +from .neuroexplorer import NeuroExplorerRecordingExtractor, read_neuroexplorer from .nix import NixRecordingExtractor, read_nix from .openephys import ( OpenEphysLegacyRecordingExtractor, @@ -53,6 +54,7 @@ SpikeGadgetsRecordingExtractor, SpikeGLXRecordingExtractor, TdtRecordingExtractor, + NeuroExplorerRecordingExtractor, ] neo_sorting_extractors_list = [BlackrockSortingExtractor, MEArecSortingExtractor, NeuralynxSortingExtractor] diff --git a/src/spikeinterface/extractors/neoextractors/neuroexplorer.py b/src/spikeinterface/extractors/neoextractors/neuroexplorer.py new file mode 100644 index 0000000000..e936d91fbf --- /dev/null +++ b/src/spikeinterface/extractors/neoextractors/neuroexplorer.py @@ -0,0 +1,66 @@ +from pathlib import Path + +from spikeinterface.core.core_tools import define_function_from_class + +from .neobaseextractor import NeoBaseRecordingExtractor + + +class NeuroExplorerRecordingExtractor(NeoBaseRecordingExtractor): + """ + Class for reading NEX (NeuroExplorer data format) files. + + Based on :py:class:`neo.rawio.NeuroExplorerRawIO` + + Importantly, at the moment, this recorder only extracts one channel of the recording. + This is because the NeuroExplorerRawIO class does not support multi-channel recordings + as in the NeuroExplorer format they might have different sampling rates. + + Consider exctracting all the channels and then concatenating them with the concatenate_recordings function. 
+ + >>> from spikeinterface.extractors.neoextractors.neuroexplorer import NeuroExplorerRecordingExtractor + >>> from spikeinterface.core import aggregate_channels + >>> + >>> file_path="/home/heberto/spikeinterface_datasets/ephy_testing_data/neuroexplorer/File_neuroexplorer_1.nex" + >>> + >>> streams = NeuroExplorerRecordingExtractor.get_streams(file_path=file_path) + >>> stream_names = streams[0] + >>> + >>> your_signal_stream_names = "Here goes the logic to filter from stream names the ones that you know have the same sampling rate and you want to aggregate" + >>> + >>> recording_list = [NeuroExplorerRecordingExtractor(file_path=file_path, stream_name=stream_name) for stream_name in your_signal_stream_names] + >>> recording = aggregate_channels(recording_list) + + + + Parameters + ---------- + file_path: str + The file path to load the recordings from. + stream_id: str, optional + If there are several streams, specify the stream id you want to load. + For this neo reader streams are defined by their sampling frequency. + stream_name: str, optional + If there are several streams, specify the stream name you want to load. + all_annotations: bool, default: False + Load exhaustively all annotations from neo. + """ + + mode = "file" + NeoRawIOClass = "NeuroExplorerRawIO" + name = "neuroexplorer" + + def __init__(self, file_path, stream_id=None, stream_name=None, all_annotations=False): + neo_kwargs = {"filename": str(file_path)} + NeoBaseRecordingExtractor.__init__( + self, stream_id=stream_id, stream_name=stream_name, all_annotations=all_annotations, **neo_kwargs + ) + self._kwargs.update({"file_path": str(Path(file_path).absolute())}) + self.extra_requirements.append("neo[edf]") + + @classmethod + def map_to_neo_kwargs(cls, file_path): + neo_kwargs = {"filename": str(file_path)} + return neo_kwargs + + +read_neuroexplorer = define_function_from_class(source_class=NeuroExplorerRecordingExtractor, name="read_neuroexplorer") diff --git a/src/spikeinterface/extractors/tests/test_neoextractors.py b/src/spikeinterface/extractors/tests/test_neoextractors.py index 900bdec06e..a62c81fc00 100644 --- a/src/spikeinterface/extractors/tests/test_neoextractors.py +++ b/src/spikeinterface/extractors/tests/test_neoextractors.py @@ -109,6 +109,17 @@ class NeuroScopeRecordingTest(RecordingCommonTestSuite, unittest.TestCase): ] +class NeuroExplorerRecordingTest(RecordingCommonTestSuite, unittest.TestCase): + ExtractorClass = NeuroExplorerRecordingExtractor + downloads = ["neuroexplorer"] + entities = [ + ("neuroexplorer/File_neuroexplorer_1.nex", {"stream_name": "ContChannel01"}), + ("neuroexplorer/File_neuroexplorer_1.nex", {"stream_name": "ContChannel02"}), + ("neuroexplorer/File_neuroexplorer_2.nex", {"stream_name": "ContChannel01"}), + ("neuroexplorer/File_neuroexplorer_2.nex", {"stream_name": "ContChannel02"}), + ] + + class NeuroScopeSortingTest(SortingCommonTestSuite, unittest.TestCase): ExtractorClass = NeuroScopeSortingExtractor downloads = ["neuroscope"] From 842de8db06b6268651e382af81e1dfb3cfb06822 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Wed, 6 Sep 2023 07:31:11 -0400 Subject: [PATCH 30/73] comment out mountainsort4/5 fail on pip install for mac --- installation_tips/full_spikeinterface_environment_mac.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/installation_tips/full_spikeinterface_environment_mac.yml b/installation_tips/full_spikeinterface_environment_mac.yml index 867de7f98b..7ce4a149cc 100755 --- 
a/installation_tips/full_spikeinterface_environment_mac.yml +++ b/installation_tips/full_spikeinterface_environment_mac.yml @@ -35,5 +35,5 @@ dependencies: - spikeinterface-gui - tridesclous>=1.6.8 # - phy==2.0b5 - - mountainsort4>=1.0.0 - - mountainsort5>=0.3.0 + # - mountainsort4>=1.0.0 isosplit5 fails on pip install for mac + # - mountainsort5>=0.3.0 From 2596d122df5cf368854e3806a684d5ca4f3adcff Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Wed, 6 Sep 2023 15:23:15 +0200 Subject: [PATCH 31/73] improve generate --- src/spikeinterface/core/generate.py | 57 ++++++++++++++++------------- 1 file changed, 31 insertions(+), 26 deletions(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 93b9459b5f..617a39b6bc 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -2,6 +2,7 @@ import numpy as np from typing import Union, Optional, List, Literal +import warnings from .numpyextractors import NumpyRecording, NumpySorting @@ -31,7 +32,7 @@ def generate_recording( set_probe: Optional[bool] = True, ndim: Optional[int] = 2, seed: Optional[int] = None, - mode: Literal["lazy", "legacy"] = "legacy", + mode: Literal["lazy", "legacy"] = "lazy", ) -> BaseRecording: """ Generate a recording object. @@ -51,10 +52,10 @@ def generate_recording( The number of dimensions of the probe, by default 2. Set to 3 to make 3 dimensional probes. seed : Optional[int] A seed for the np.ramdom.default_rng function - mode: str ["lazy", "legacy"] Default "legacy". + mode: str ["lazy", "legacy"] Default "lazy". "legacy": generate a NumpyRecording with white noise. - This mode is kept for backward compatibility and will be deprecated in next release. - "lazy": return a NoiseGeneratorRecording + This mode is kept for backward compatibility and will be deprecated version 0.100.0. + "lazy": return a NoiseGeneratorRecording instance. Returns ------- @@ -64,6 +65,10 @@ def generate_recording( seed = _ensure_seed(seed) if mode == "legacy": + warnings.warn( + "generate_recording() : mode='legacy' will be deprecated in version 0.100.0. Use mode='lazy' instead.", + DeprecationWarning, + ) recording = _generate_recording_legacy(num_channels, sampling_frequency, durations, seed) elif mode == "lazy": recording = NoiseGeneratorRecording( @@ -460,7 +465,7 @@ def synthetize_spike_train_bad_isi(duration, baseline_rate, num_violations, viol class NoiseGeneratorRecording(BaseRecording): """ - A lazy recording that generates random samples if and only if `get_traces` is called. + A lazy recording that generates white noise samples if and only if `get_traces` is called. This done by tiling small noise chunk. @@ -477,7 +482,7 @@ class NoiseGeneratorRecording(BaseRecording): The sampling frequency of the recorder. durations : List[float] The durations of each segment in seconds. Note that the length of this list is the number of segments. - noise_level: float, default 5: + noise_level: float, default 1: Std of the white noise dtype : Optional[Union[np.dtype, str]], default='float32' The dtype of the recording. Note that only np.float32 and np.float64 are supported. 
@@ -503,7 +508,7 @@ def __init__( num_channels: int, sampling_frequency: float, durations: List[float], - noise_level: float = 5.0, + noise_level: float = 1.0, dtype: Optional[Union[np.dtype, str]] = "float32", seed: Optional[int] = None, strategy: Literal["tile_pregenerated", "on_the_fly"] = "tile_pregenerated", @@ -569,7 +574,7 @@ def __init__( if self.strategy == "tile_pregenerated": rng = np.random.default_rng(seed=self.seed) self.noise_block = ( - rng.standard_normal(size=(self.noise_block_size, self.num_channels)).astype(self.dtype) * noise_level + rng.standard_normal(size=(self.noise_block_size, self.num_channels), dtype=self.dtype) * noise_level ) elif self.strategy == "on_the_fly": pass @@ -586,35 +591,35 @@ def get_traces( start_frame = 0 if start_frame is None else max(start_frame, 0) end_frame = self.num_samples if end_frame is None else min(end_frame, self.num_samples) - start_frame_mod = start_frame % self.noise_block_size - end_frame_mod = end_frame % self.noise_block_size + start_frame_within_block = start_frame % self.noise_block_size + end_frame_within_block = end_frame % self.noise_block_size num_samples = end_frame - start_frame traces = np.empty(shape=(num_samples, self.num_channels), dtype=self.dtype) - start_block_index = start_frame // self.noise_block_size - end_block_index = end_frame // self.noise_block_size + first_block_index = start_frame // self.noise_block_size + last_block_index = end_frame // self.noise_block_size pos = 0 - for block_index in range(start_block_index, end_block_index + 1): + for block_index in range(first_block_index, last_block_index + 1): if self.strategy == "tile_pregenerated": noise_block = self.noise_block elif self.strategy == "on_the_fly": rng = np.random.default_rng(seed=(self.seed, block_index)) - noise_block = rng.standard_normal(size=(self.noise_block_size, self.num_channels)).astype(self.dtype) + noise_block = rng.standard_normal(size=(self.noise_block_size, self.num_channels), dtype=self.dtype) noise_block *= self.noise_level - if block_index == start_block_index: - if start_block_index != end_block_index: - end_first_block = self.noise_block_size - start_frame_mod - traces[:end_first_block] = noise_block[start_frame_mod:] + if block_index == first_block_index: + if first_block_index != last_block_index: + end_first_block = self.noise_block_size - start_frame_within_block + traces[:end_first_block] = noise_block[start_frame_within_block:] pos += end_first_block else: # special case when unique block - traces[:] = noise_block[start_frame_mod : start_frame_mod + traces.shape[0]] - elif block_index == end_block_index: - if end_frame_mod > 0: - traces[pos:] = noise_block[:end_frame_mod] + traces[:] = noise_block[start_frame_within_block : start_frame_within_block + num_samples] + elif block_index == last_block_index: + if end_frame_within_block > 0: + traces[pos:] = noise_block[:end_frame_within_block] else: traces[pos : pos + self.noise_block_size] = noise_block pos += self.noise_block_size @@ -632,7 +637,7 @@ def get_traces( def generate_recording_by_size( full_traces_size_GiB: float, - num_channels: int = 1024, + num_channels: int = 384, seed: Optional[int] = None, strategy: Literal["tile_pregenerated", "on_the_fly"] = "tile_pregenerated", ) -> NoiseGeneratorRecording: @@ -641,7 +646,7 @@ def generate_recording_by_size( This is a convenience wrapper around the NoiseGeneratorRecording class where only the size in GiB (NOT GB!) is specified. - It is generated with 1024 channels and a sampling frequency of 1 Hz. 
The duration is manipulted to + It is generated with 384 channels and a sampling frequency of 1 Hz. The duration is manipulted to produced the desired size. Seee GeneratorRecording for more details. @@ -649,7 +654,7 @@ def generate_recording_by_size( Parameters ---------- full_traces_size_GiB : float - The size in gibibyte (GiB) of the recording. + The size in gigabytes (GiB) of the recording. num_channels: int Number of channels. seed : int, optional @@ -662,7 +667,7 @@ def generate_recording_by_size( dtype = np.dtype("float32") sampling_frequency = 30_000.0 # Hz - num_channels = 1024 + num_channels = 384 GiB_to_bytes = 1024**3 full_traces_size_bytes = int(full_traces_size_GiB * GiB_to_bytes) From df226a18a62c897a7a9399afd291a84b6a4c7a76 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Wed, 6 Sep 2023 16:57:05 +0200 Subject: [PATCH 32/73] normalize scale --- .../preprocessing/tests/test_normalize_scale.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/preprocessing/tests/test_normalize_scale.py b/src/spikeinterface/preprocessing/tests/test_normalize_scale.py index b62a73a8cb..08ca56cbec 100644 --- a/src/spikeinterface/preprocessing/tests/test_normalize_scale.py +++ b/src/spikeinterface/preprocessing/tests/test_normalize_scale.py @@ -78,13 +78,18 @@ def test_zscore(): assert np.all(np.abs(np.mean(tr, axis=0)) < 0.01) assert np.all(np.abs(np.std(tr, axis=0) - 1) < 0.01) + +def test_zscore_int(): + seed = 0 + rec = generate_recording(seed=seed, mode="lazy") rec_int = scale(rec, dtype="int16", gain=100) with pytest.raises(AssertionError): - rec4 = zscore(rec_int, dtype=None) - rec4 = zscore(rec_int, dtype="int16", int_scale=256, mode="mean+std", seed=seed) - tr = rec4.get_traces(segment_index=0) - trace_mean = np.mean(tr, axis=0) - trace_std = np.std(tr, axis=0) + zscore(rec_int, dtype=None) + + zscore_recording = zscore(rec_int, dtype="int16", int_scale=256, mode="mean+std", seed=seed) + traces = zscore_recording.get_traces(segment_index=0) + trace_mean = np.mean(traces, axis=0) + trace_std = np.std(traces, axis=0) assert np.all(np.abs(trace_mean) < 1) assert np.all(np.abs(trace_std - 256) < 1) From 09394787e4f15c4268bc0ef8d9883e02163348ee Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Wed, 6 Sep 2023 18:02:58 +0200 Subject: [PATCH 33/73] not test for bug right now --- src/spikeinterface/preprocessing/normalize_scale.py | 2 +- .../preprocessing/tests/test_normalize_scale.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/preprocessing/normalize_scale.py b/src/spikeinterface/preprocessing/normalize_scale.py index 90b39aee8a..7d43982853 100644 --- a/src/spikeinterface/preprocessing/normalize_scale.py +++ b/src/spikeinterface/preprocessing/normalize_scale.py @@ -293,7 +293,7 @@ def __init__( means = means[None, :] stds = np.std(random_data, axis=0) stds = stds[None, :] - gain = 1 / stds + gain = 1.0 / stds offset = -means / stds if int_scale is not None: diff --git a/src/spikeinterface/preprocessing/tests/test_normalize_scale.py b/src/spikeinterface/preprocessing/tests/test_normalize_scale.py index 08ca56cbec..764acc9852 100644 --- a/src/spikeinterface/preprocessing/tests/test_normalize_scale.py +++ b/src/spikeinterface/preprocessing/tests/test_normalize_scale.py @@ -80,8 +80,8 @@ def test_zscore(): def test_zscore_int(): - seed = 0 - rec = generate_recording(seed=seed, mode="lazy") + seed = 1 + rec = generate_recording(seed=seed, mode="legacy") rec_int = scale(rec, dtype="int16", 
gain=100) with pytest.raises(AssertionError): zscore(rec_int, dtype=None) From 55e8c51d16dfc2be6ca3dadc18d630fa9507fda1 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Wed, 6 Sep 2023 21:14:20 -0400 Subject: [PATCH 34/73] update linux and draft a github action --- .github/workflows/installation-tips-test.yml | 43 +++++++++++++++++++ ...spikeinterface_environment_linux_dandi.yml | 11 ++--- 2 files changed, 47 insertions(+), 7 deletions(-) create mode 100644 .github/workflows/installation-tips-test.yml diff --git a/.github/workflows/installation-tips-test.yml b/.github/workflows/installation-tips-test.yml new file mode 100644 index 0000000000..efee9ef370 --- /dev/null +++ b/.github/workflows/installation-tips-test.yml @@ -0,0 +1,43 @@ +name: Creates Conda Install for Installation Tips + +on: + workflow_dispatch: + pull_request: + types: [synchronize, opened, reopened] + schedule: + - cron: "0 12 * * *" # Daily at noon UTC + +jobs: + testing: + name: Build Conda Env on ${{ matrix.os }} OS + runs-on: ${{ matrix.os }} + defaults: + run: + shell: bash -el {0} + strategy: + fail-fast: false + matrix: + include: + - os: ubuntu-latest + label: linux_dandi + - os: macos-latest + label: mac + - os: windows-latest + label: windows + steps: + - uses: actions/checkout@v3 + - name: Set up Python 3.10 + uses: actions/setup-python@v3 + with: + python-version: '3.10' + - name: Add conda to system path + run: | + # $CONDA is an environment variable pointing to the root of the miniconda directory + echo $CONDA/bin >> $GITHUB_PATH + + - name: Test Conda Environment Creation + uses: conda-incubator/setup-miniconda@v2.2.0 + with: + environment-file: ./installations_tips/full_spikeinterface_environment_${{ matrix.label }}.yml + + \ No newline at end of file diff --git a/installation_tips/full_spikeinterface_environment_linux_dandi.yml b/installation_tips/full_spikeinterface_environment_linux_dandi.yml index d402f6805f..2ed176b16c 100755 --- a/installation_tips/full_spikeinterface_environment_linux_dandi.yml +++ b/installation_tips/full_spikeinterface_environment_linux_dandi.yml @@ -3,13 +3,11 @@ channels: - conda-forge - defaults dependencies: - - python=3.9 + - python=3.10 - pip>=21.0 - mamba - # numpy 1.22 break numba which break tridesclous - numpy<1.22 - # joblib 1.2 is breaking hdbscan - - joblib=1.1 + - joblib - tqdm - matplotlib - h5py @@ -31,12 +29,11 @@ dependencies: - ipympl - pip: - ephyviewer - - neo>=0.11 - - elephant>=0.10.0 + - neo>=0.12 - probeinterface>=0.2.11 - MEArec>=1.8 - spikeinterface[full, widgets] - spikeinterface-gui - - tridesclous>=1.6.6.1 + - tridesclous>=1.6.8 - spyking-circus>=1.1.0 # - phy==2.0b5 From 807f0ad3e4e93c524eac28e376e2253b72e6aa0f Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Wed, 6 Sep 2023 21:22:48 -0400 Subject: [PATCH 35/73] test on push for github actions --- installation-tips-test.yml | 40 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) create mode 100644 installation-tips-test.yml diff --git a/installation-tips-test.yml b/installation-tips-test.yml new file mode 100644 index 0000000000..2628b08529 --- /dev/null +++ b/installation-tips-test.yml @@ -0,0 +1,40 @@ +name: Creates Conda Install for Installation Tips + +on: + workflow_dispatch: + push: + schedule: + - cron: "0 12 * * *" # Daily at noon UTC + +jobs: + testing: + name: Build Conda Env on ${{ matrix.os }} OS + runs-on: ${{ matrix.os }} + defaults: + run: + shell: bash -el {0} + strategy: + fail-fast: false + 
matrix: + include: + - os: ubuntu-latest + label: linux_dandi + - os: macos-latest + label: mac + - os: windows-latest + label: windows + steps: + - uses: actions/checkout@v3 + - name: Set up Python 3.10 + uses: actions/setup-python@v3 + with: + python-version: '3.10' + - name: Add conda to system path + run: | + # $CONDA is an environment variable pointing to the root of the miniconda directory + echo $CONDA/bin >> $GITHUB_PATH + + - name: Test Conda Environment Creation + uses: conda-incubator/setup-miniconda@v2.2.0 + with: + environment-file: ./installations_tips/full_spikeinterface_environment_${{ matrix.label }}.yml From 6a37b0da0c9483171893bfe9c3e89ea51af35a44 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Wed, 6 Sep 2023 21:29:36 -0400 Subject: [PATCH 36/73] fix misaligned run --- installation-tips-test.yml | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/installation-tips-test.yml b/installation-tips-test.yml index 2628b08529..c6c32b812b 100644 --- a/installation-tips-test.yml +++ b/installation-tips-test.yml @@ -7,7 +7,7 @@ on: - cron: "0 12 * * *" # Daily at noon UTC jobs: - testing: + installation-tips-testing: name: Build Conda Env on ${{ matrix.os }} OS runs-on: ${{ matrix.os }} defaults: @@ -25,15 +25,9 @@ jobs: label: windows steps: - uses: actions/checkout@v3 - - name: Set up Python 3.10 - uses: actions/setup-python@v3 + uses: actions/setup-python@v4 with: python-version: '3.10' - - name: Add conda to system path - run: | - # $CONDA is an environment variable pointing to the root of the miniconda directory - echo $CONDA/bin >> $GITHUB_PATH - - name: Test Conda Environment Creation uses: conda-incubator/setup-miniconda@v2.2.0 with: From 918cd67c699eb7ce353a90d6d68408a4b0734e8b Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Wed, 6 Sep 2023 21:32:05 -0400 Subject: [PATCH 37/73] fix my folder mistake --- .github/workflows/installation-tips-test.yml | 15 ++------- installation-tips-test.yml | 34 -------------------- 2 files changed, 3 insertions(+), 46 deletions(-) delete mode 100644 installation-tips-test.yml diff --git a/.github/workflows/installation-tips-test.yml b/.github/workflows/installation-tips-test.yml index efee9ef370..c6c32b812b 100644 --- a/.github/workflows/installation-tips-test.yml +++ b/.github/workflows/installation-tips-test.yml @@ -2,13 +2,12 @@ name: Creates Conda Install for Installation Tips on: workflow_dispatch: - pull_request: - types: [synchronize, opened, reopened] + push: schedule: - cron: "0 12 * * *" # Daily at noon UTC jobs: - testing: + installation-tips-testing: name: Build Conda Env on ${{ matrix.os }} OS runs-on: ${{ matrix.os }} defaults: @@ -26,18 +25,10 @@ jobs: label: windows steps: - uses: actions/checkout@v3 - - name: Set up Python 3.10 - uses: actions/setup-python@v3 + uses: actions/setup-python@v4 with: python-version: '3.10' - - name: Add conda to system path - run: | - # $CONDA is an environment variable pointing to the root of the miniconda directory - echo $CONDA/bin >> $GITHUB_PATH - - name: Test Conda Environment Creation uses: conda-incubator/setup-miniconda@v2.2.0 with: environment-file: ./installations_tips/full_spikeinterface_environment_${{ matrix.label }}.yml - - \ No newline at end of file diff --git a/installation-tips-test.yml b/installation-tips-test.yml deleted file mode 100644 index c6c32b812b..0000000000 --- a/installation-tips-test.yml +++ /dev/null @@ -1,34 +0,0 @@ -name: Creates Conda Install for Installation 
Tips - -on: - workflow_dispatch: - push: - schedule: - - cron: "0 12 * * *" # Daily at noon UTC - -jobs: - installation-tips-testing: - name: Build Conda Env on ${{ matrix.os }} OS - runs-on: ${{ matrix.os }} - defaults: - run: - shell: bash -el {0} - strategy: - fail-fast: false - matrix: - include: - - os: ubuntu-latest - label: linux_dandi - - os: macos-latest - label: mac - - os: windows-latest - label: windows - steps: - - uses: actions/checkout@v3 - uses: actions/setup-python@v4 - with: - python-version: '3.10' - - name: Test Conda Environment Creation - uses: conda-incubator/setup-miniconda@v2.2.0 - with: - environment-file: ./installations_tips/full_spikeinterface_environment_${{ matrix.label }}.yml From b815ed91c0a8b2f52a44fb9f83b45eb2987184cf Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Wed, 6 Sep 2023 21:35:20 -0400 Subject: [PATCH 38/73] fix spacing of action --- .github/workflows/installation-tips-test.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/installation-tips-test.yml b/.github/workflows/installation-tips-test.yml index c6c32b812b..7285dfa9f5 100644 --- a/.github/workflows/installation-tips-test.yml +++ b/.github/workflows/installation-tips-test.yml @@ -28,7 +28,7 @@ jobs: uses: actions/setup-python@v4 with: python-version: '3.10' - - name: Test Conda Environment Creation - uses: conda-incubator/setup-miniconda@v2.2.0 - with: + - name: Test Conda Environment Creation + uses: conda-incubator/setup-miniconda@v2.2.0 + with: environment-file: ./installations_tips/full_spikeinterface_environment_${{ matrix.label }}.yml From 2bbead38f3309428bceb4819c9328b64501915ae Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Wed, 6 Sep 2023 21:36:38 -0400 Subject: [PATCH 39/73] final fix I hope --- .github/workflows/installation-tips-test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/installation-tips-test.yml b/.github/workflows/installation-tips-test.yml index 7285dfa9f5..c3c70db66e 100644 --- a/.github/workflows/installation-tips-test.yml +++ b/.github/workflows/installation-tips-test.yml @@ -25,7 +25,7 @@ jobs: label: windows steps: - uses: actions/checkout@v3 - uses: actions/setup-python@v4 + - uses: actions/setup-python@v4 with: python-version: '3.10' - name: Test Conda Environment Creation From 03f9da463901b9a33a6c27b50d93e80d569c0f4a Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Wed, 6 Sep 2023 21:40:10 -0400 Subject: [PATCH 40/73] delete on push, deleting trailing white --- .github/workflows/installation-tips-test.yml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/installation-tips-test.yml b/.github/workflows/installation-tips-test.yml index c3c70db66e..a985a6d60d 100644 --- a/.github/workflows/installation-tips-test.yml +++ b/.github/workflows/installation-tips-test.yml @@ -2,7 +2,6 @@ name: Creates Conda Install for Installation Tips on: workflow_dispatch: - push: schedule: - cron: "0 12 * * *" # Daily at noon UTC @@ -31,4 +30,4 @@ jobs: - name: Test Conda Environment Creation uses: conda-incubator/setup-miniconda@v2.2.0 with: - environment-file: ./installations_tips/full_spikeinterface_environment_${{ matrix.label }}.yml + environment-file: ./installations_tips/full_spikeinterface_environment_${{ matrix.label }}.yml \ No newline at end of file From cf7ee82752bdb32d6960676f0599651eee52b98b Mon Sep 17 00:00:00 2001 From: zm711 
<92116279+zm711@users.noreply.github.com> Date: Thu, 7 Sep 2023 05:38:18 -0400 Subject: [PATCH 41/73] local pre-commit --- .github/workflows/installation-tips-test.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/installation-tips-test.yml b/.github/workflows/installation-tips-test.yml index a985a6d60d..64a4e45270 100644 --- a/.github/workflows/installation-tips-test.yml +++ b/.github/workflows/installation-tips-test.yml @@ -4,7 +4,7 @@ on: workflow_dispatch: schedule: - cron: "0 12 * * *" # Daily at noon UTC - + jobs: installation-tips-testing: name: Build Conda Env on ${{ matrix.os }} OS @@ -30,4 +30,4 @@ jobs: - name: Test Conda Environment Creation uses: conda-incubator/setup-miniconda@v2.2.0 with: - environment-file: ./installations_tips/full_spikeinterface_environment_${{ matrix.label }}.yml \ No newline at end of file + environment-file: ./installations_tips/full_spikeinterface_environment_${{ matrix.label }}.yml From 34538311d1c00f5f45b4cb84a03984efa9ef4f3d Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 7 Sep 2023 18:38:25 +0200 Subject: [PATCH 42/73] fedeback from alessio and ramon --- src/spikeinterface/core/waveform_tools.py | 23 +++++++++-------------- 1 file changed, 9 insertions(+), 14 deletions(-) diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index 252ea68738..53c0df68df 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -258,8 +258,8 @@ def distribute_waveforms_to_buffers( inds_by_unit[unit_id] = inds # and run - func = _worker_ditribute_buffers - init_func = _init_worker_ditribute_buffers + func = _worker_distribute_buffers + init_func = _init_worker_distribute_buffers init_args = ( recording, @@ -283,7 +283,7 @@ def distribute_waveforms_to_buffers( # used by ChunkRecordingExecutor -def _init_worker_ditribute_buffers( +def _init_worker_distribute_buffers( recording, unit_ids, spikes, arrays_info, nbefore, nafter, return_scaled, inds_by_unit, mode, sparsity_mask ): # create a local dict per worker @@ -329,7 +329,7 @@ def _init_worker_ditribute_buffers( # used by ChunkRecordingExecutor -def _worker_ditribute_buffers(segment_index, start_frame, end_frame, worker_ctx): +def _worker_distribute_buffers(segment_index, start_frame, end_frame, worker_ctx): # recover variables of the worker recording = worker_ctx["recording"] unit_ids = worker_ctx["unit_ids"] @@ -480,7 +480,7 @@ def extract_waveforms_to_single_buffer( shape = (num_spikes, nsamples, num_chans) if mode == "memmap": - filename = str(folder / f"all_waveforms.npy") + filename = str(folder / f"waveforms.npy") all_waveforms = np.lib.format.open_memmap(filename, mode="w+", dtype=dtype, shape=shape) wf_array_info = filename elif mode == "shared_memory": @@ -497,15 +497,10 @@ def extract_waveforms_to_single_buffer( job_kwargs = fix_job_kwargs(job_kwargs) - inds_by_unit = {} - for unit_ind, unit_id in enumerate(unit_ids): - (inds,) = np.nonzero(spikes["unit_index"] == unit_ind) - inds_by_unit[unit_id] = inds - if num_spikes > 0: # and run - func = _worker_ditribute_single_buffer - init_func = _init_worker_ditribute_single_buffer + func = _worker_distribute_single_buffer + init_func = _init_worker_distribute_single_buffer init_args = ( recording, @@ -537,7 +532,7 @@ def extract_waveforms_to_single_buffer( return all_waveforms, wf_array_info -def _init_worker_ditribute_single_buffer( +def _init_worker_distribute_single_buffer( recording, unit_ids, spikes, wf_array_info, 
nbefore, nafter, return_scaled, mode, sparsity_mask ): worker_ctx = {} @@ -576,7 +571,7 @@ def _init_worker_ditribute_single_buffer( # used by ChunkRecordingExecutor -def _worker_ditribute_single_buffer(segment_index, start_frame, end_frame, worker_ctx): +def _worker_distribute_single_buffer(segment_index, start_frame, end_frame, worker_ctx): # recover variables of the worker recording = worker_ctx["recording"] unit_ids = worker_ctx["unit_ids"] From d79dbe26bb1d3f5db45b1ac36d84f7b96be08f18 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 7 Sep 2023 18:51:26 +0200 Subject: [PATCH 43/73] extract_waveforms_to_single_buffer change folder to file_path --- .../core/tests/test_waveform_tools.py | 27 ++++++++++++------- src/spikeinterface/core/waveform_tools.py | 17 +++++------- 2 files changed, 24 insertions(+), 20 deletions(-) diff --git a/src/spikeinterface/core/tests/test_waveform_tools.py b/src/spikeinterface/core/tests/test_waveform_tools.py index 52d7472c92..1d7e38832a 100644 --- a/src/spikeinterface/core/tests/test_waveform_tools.py +++ b/src/spikeinterface/core/tests/test_waveform_tools.py @@ -59,14 +59,15 @@ def test_waveform_tools(): ] some_modes = [ {"mode": "memmap"}, + {"mode": "shared_memory"}, ] - if platform.system() != "Windows": - # shared memory on windows is buggy... - some_modes.append( - { - "mode": "shared_memory", - } - ) + # if platform.system() != "Windows": + # # shared memory on windows is buggy... + # some_modes.append( + # { + # "mode": "shared_memory", + # } + # ) some_sparsity = [ dict(sparsity_mask=None), @@ -87,9 +88,11 @@ def test_waveform_tools(): if wf_folder.is_dir(): shutil.rmtree(wf_folder) wf_folder.mkdir(parents=True) - mode_kwargs_ = dict(**mode_kwargs, folder=wf_folder) - else: - mode_kwargs_ = mode_kwargs + wf_file_path = wf_folder / "waveforms_all_units.npy" + + mode_kwargs_ = dict(**mode_kwargs) + if mode_kwargs["mode"] == "memmap": + mode_kwargs_["folder" ] = wf_folder wfs_arrays = extract_waveforms_to_buffers( recording, @@ -113,6 +116,10 @@ def test_waveform_tools(): else: list_wfs_sparse.append(wfs_arrays) + mode_kwargs_ = dict(**mode_kwargs) + if mode_kwargs["mode"] == "memmap": + mode_kwargs_["file_path" ] = wf_file_path + all_waveforms = extract_waveforms_to_single_buffer( recording, spikes, diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index 53c0df68df..c363ac49dc 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -406,7 +406,7 @@ def extract_waveforms_to_single_buffer( nafter, mode="memmap", return_scaled=False, - folder=None, + file_path=None, dtype=None, sparsity_mask=None, copy=False, @@ -442,8 +442,8 @@ def extract_waveforms_to_single_buffer( Mode to use ('memmap' | 'shared_memory') return_scaled: bool Scale traces before exporting to buffer or not. - folder: str or path - In case of memmap mode, folder to save npy files + file_path: str or path + In case of memmap mode, file to save npy file. 
dtype: numpy.dtype dtype for waveforms buffer sparsity_mask: None or array of bool @@ -468,9 +468,9 @@ def extract_waveforms_to_single_buffer( dtype = np.dtype(dtype) if mode == "shared_memory": - assert folder is None + assert file_path is None else: - folder = Path(folder) + file_path = Path(file_path) num_spikes = spikes.size if sparsity_mask is None: @@ -480,9 +480,8 @@ def extract_waveforms_to_single_buffer( shape = (num_spikes, nsamples, num_chans) if mode == "memmap": - filename = str(folder / f"waveforms.npy") - all_waveforms = np.lib.format.open_memmap(filename, mode="w+", dtype=dtype, shape=shape) - wf_array_info = filename + all_waveforms = np.lib.format.open_memmap(file_path, mode="w+", dtype=dtype, shape=shape) + wf_array_info = str(file_path) elif mode == "shared_memory": if num_spikes == 0 or num_chans == 0: all_waveforms = np.zeros(shape, dtype=dtype) @@ -538,7 +537,6 @@ def _init_worker_distribute_single_buffer( worker_ctx = {} worker_ctx["recording"] = recording worker_ctx["wf_array_info"] = wf_array_info - worker_ctx["unit_ids"] = unit_ids worker_ctx["spikes"] = spikes worker_ctx["nbefore"] = nbefore worker_ctx["nafter"] = nafter @@ -574,7 +572,6 @@ def _init_worker_distribute_single_buffer( def _worker_distribute_single_buffer(segment_index, start_frame, end_frame, worker_ctx): # recover variables of the worker recording = worker_ctx["recording"] - unit_ids = worker_ctx["unit_ids"] segment_slices = worker_ctx["segment_slices"] spikes = worker_ctx["spikes"] nbefore = worker_ctx["nbefore"] From 5c6615975704bd9bbbda722291d04ff0ccfb3c90 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 7 Sep 2023 16:51:54 +0000 Subject: [PATCH 44/73] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/tests/test_waveform_tools.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/core/tests/test_waveform_tools.py b/src/spikeinterface/core/tests/test_waveform_tools.py index 1d7e38832a..e9cf1bfb5f 100644 --- a/src/spikeinterface/core/tests/test_waveform_tools.py +++ b/src/spikeinterface/core/tests/test_waveform_tools.py @@ -92,7 +92,7 @@ def test_waveform_tools(): mode_kwargs_ = dict(**mode_kwargs) if mode_kwargs["mode"] == "memmap": - mode_kwargs_["folder" ] = wf_folder + mode_kwargs_["folder"] = wf_folder wfs_arrays = extract_waveforms_to_buffers( recording, @@ -118,8 +118,8 @@ def test_waveform_tools(): mode_kwargs_ = dict(**mode_kwargs) if mode_kwargs["mode"] == "memmap": - mode_kwargs_["file_path" ] = wf_file_path - + mode_kwargs_["file_path"] = wf_file_path + all_waveforms = extract_waveforms_to_single_buffer( recording, spikes, From 7080696f12617bfd08769c89a8471768878191ca Mon Sep 17 00:00:00 2001 From: Garcia Samuel Date: Fri, 8 Sep 2023 10:17:48 +0200 Subject: [PATCH 45/73] Update src/spikeinterface/core/waveform_tools.py Co-authored-by: Alessio Buccino --- src/spikeinterface/core/waveform_tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index c363ac49dc..a63d0a80b7 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -417,7 +417,7 @@ def extract_waveforms_to_single_buffer( Allocate a single buffer (memmap or or shared memory) and then distribute every waveform into it. 
Contrary to extract_waveforms_to_buffers() all waveforms are extracted in the same buffer, so the spike vector is - needed to recover waveforms unit by unit. Importantly in case of sparsity, the channel are not aligned across + needed to recover waveforms unit by unit. Importantly in case of sparsity, the channels are not aligned across units. Important note: for the "shared_memory" mode wf_array_info contains reference to From 99602f17ed4793bbf21b577cd8f87860cc3c3c2b Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 11 Sep 2023 10:46:51 +0200 Subject: [PATCH 46/73] Make plexon2 tests conditional on Wine dependency (on Linux) --- .../extractors/tests/test_neoextractors.py | 40 ++++++++++++++++++- 1 file changed, 38 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/extractors/tests/test_neoextractors.py b/src/spikeinterface/extractors/tests/test_neoextractors.py index da162eccf1..5fe42b0c4e 100644 --- a/src/spikeinterface/extractors/tests/test_neoextractors.py +++ b/src/spikeinterface/extractors/tests/test_neoextractors.py @@ -1,5 +1,6 @@ import unittest -from platform import python_version +import platform +import subprocess from packaging import version import pytest @@ -18,6 +19,38 @@ local_folder = get_global_dataset_folder() / "ephy_testing_data" +def has_plexon2_dependencies(): + """ + Check if required Plexon2 dependencies are installed on different OS. + """ + + os_type = platform.system() + + if os_type == "Windows": + # On Windows, no need for additional dependencies + return True + + elif os_type == "Linux": + # Check for 'wine' using dpkg + try: + result_wine = subprocess.run( + ["dpkg", "-l", "wine"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True + ) + except subprocess.CalledProcessError: + return False + + # Check for 'zugbruecke' using pip + try: + import zugbruecke + + return True + except ImportError: + return False + else: + # Not sure about MacOS + raise ValueError(f"Unsupported OS: {os_type}") + + class MearecRecordingTest(RecordingCommonTestSuite, unittest.TestCase): ExtractorClass = MEArecRecordingExtractor downloads = ["mearec"] @@ -218,7 +251,7 @@ class Spike2RecordingTest(RecordingCommonTestSuite, unittest.TestCase): @pytest.mark.skipif( - version.parse(python_version()) >= version.parse("3.10"), + version.parse(platform.python_version()) >= version.parse("3.10"), reason="Sonpy only testing with Python < 3.10!", ) class CedRecordingTest(RecordingCommonTestSuite, unittest.TestCase): @@ -291,6 +324,7 @@ def test_pickling(self): # We mark plexon2 tests as they require additional dependencies (wine) +@pytest.mark.skipif(not has_plexon2_dependencies(), reason="Required dependencies not installed") @pytest.mark.plexon2 class Plexon2RecordingTest(RecordingCommonTestSuite, unittest.TestCase): ExtractorClass = Plexon2RecordingExtractor @@ -300,6 +334,7 @@ class Plexon2RecordingTest(RecordingCommonTestSuite, unittest.TestCase): ] +@pytest.mark.skipif(not has_plexon2_dependencies(), reason="Required dependencies not installed") @pytest.mark.plexon2 class Plexon2EventTest(EventCommonTestSuite, unittest.TestCase): ExtractorClass = Plexon2EventExtractor @@ -309,6 +344,7 @@ class Plexon2EventTest(EventCommonTestSuite, unittest.TestCase): ] +@pytest.mark.skipif(not has_plexon2_dependencies(), reason="Required dependencies not installed") @pytest.mark.plexon2 class Plexon2SortingTest(SortingCommonTestSuite, unittest.TestCase): ExtractorClass = Plexon2SortingExtractor From c974ac034cfdb2b54f76e4ae7910e8f8957e6591 Mon Sep 17 00:00:00 2001 From: 
Heberto Mayorquin Date: Mon, 11 Sep 2023 11:06:14 +0200 Subject: [PATCH 47/73] Update src/spikeinterface/extractors/neoextractors/neuroexplorer.py Co-authored-by: Alessio Buccino --- src/spikeinterface/extractors/neoextractors/neuroexplorer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/extractors/neoextractors/neuroexplorer.py b/src/spikeinterface/extractors/neoextractors/neuroexplorer.py index e936d91fbf..0be65dd5cb 100644 --- a/src/spikeinterface/extractors/neoextractors/neuroexplorer.py +++ b/src/spikeinterface/extractors/neoextractors/neuroexplorer.py @@ -15,7 +15,7 @@ class NeuroExplorerRecordingExtractor(NeoBaseRecordingExtractor): This is because the NeuroExplorerRawIO class does not support multi-channel recordings as in the NeuroExplorer format they might have different sampling rates. - Consider exctracting all the channels and then concatenating them with the concatenate_recordings function. + Consider exctracting all the channels and then concatenating them with the aggregate_channels function. >>> from spikeinterface.extractors.neoextractors.neuroexplorer import NeuroExplorerRecordingExtractor >>> from spikeinterface.core import aggregate_channels From e733afe6eb103d954eea1e8992fac02f72bb51ba Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Mon, 11 Sep 2023 11:08:00 +0200 Subject: [PATCH 48/73] Update src/spikeinterface/extractors/neoextractors/neuroexplorer.py --- src/spikeinterface/extractors/neoextractors/neuroexplorer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/extractors/neoextractors/neuroexplorer.py b/src/spikeinterface/extractors/neoextractors/neuroexplorer.py index 0be65dd5cb..b430e45232 100644 --- a/src/spikeinterface/extractors/neoextractors/neuroexplorer.py +++ b/src/spikeinterface/extractors/neoextractors/neuroexplorer.py @@ -20,7 +20,7 @@ class NeuroExplorerRecordingExtractor(NeoBaseRecordingExtractor): >>> from spikeinterface.extractors.neoextractors.neuroexplorer import NeuroExplorerRecordingExtractor >>> from spikeinterface.core import aggregate_channels >>> - >>> file_path="/home/heberto/spikeinterface_datasets/ephy_testing_data/neuroexplorer/File_neuroexplorer_1.nex" + >>> file_path="/the/path/to/your/nex/file.nex" >>> >>> streams = NeuroExplorerRecordingExtractor.get_streams(file_path=file_path) >>> stream_names = streams[0] From d7d34c4c676e49724da0c79aab2c7605865473bd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 11 Sep 2023 20:49:24 +0000 Subject: [PATCH 49/73] [pre-commit.ci] pre-commit autoupdate MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/psf/black: 23.7.0 → 23.9.1](https://github.com/psf/black/compare/23.7.0...23.9.1) --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ced1ee6a2f..07601cd208 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,7 +6,7 @@ repos: - id: end-of-file-fixer - id: trailing-whitespace - repo: https://github.com/psf/black - rev: 23.7.0 + rev: 23.9.1 hooks: - id: black files: ^src/ From e5a523c9263fa1a229e89905639496da03dd39e0 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 12 Sep 2023 10:12:54 +0200 Subject: [PATCH 50/73] Improvement after Ramon comments. 
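Part of what this patch changes in waveform_tools.py is the chunk border handling: the old while-loop adjustment is replaced by two np.searchsorted calls that clamp each chunk to [nbefore, seg_size - nafter], so a waveform window can never read outside its segment. A self-contained sketch of that bound computation (the helper name and toy numbers are illustrative only):

    import numpy as np

    def spikes_in_chunk(spike_samples, start_frame, end_frame, nbefore, nafter, seg_size):
        # spike_samples must be sorted; returns slice bounds of the spikes whose
        # window [s - nbefore, s + nafter) fits inside both the chunk and the segment
        i0 = np.searchsorted(spike_samples, max(start_frame, nbefore))
        i1 = np.searchsorted(spike_samples, min(end_frame, seg_size - nafter))
        return i0, i1

    spike_samples = np.array([5, 40, 100, 995])   # spike times (samples) in one 1000-sample segment
    i0, i1 = spikes_in_chunk(spike_samples, start_frame=0, end_frame=1000,
                             nbefore=20, nafter=30, seg_size=1000)
    print(spike_samples[i0:i1])                   # [ 40 100] : 5 and 995 are too close to a border

Skipped spikes keep a zeroed row in the output buffer, so the number of rows still matches the number of spikes.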
--- src/spikeinterface/core/waveform_tools.py | 67 +++++++++++------------ 1 file changed, 31 insertions(+), 36 deletions(-) diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index a63d0a80b7..6e0d6f412b 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -350,16 +350,10 @@ def _worker_distribute_buffers(segment_index, start_frame, end_frame, worker_ctx # take only spikes in range [start_frame, end_frame] # this is a slice so no copy!! - i0 = np.searchsorted(in_seg_spikes["sample_index"], start_frame) - i1 = np.searchsorted(in_seg_spikes["sample_index"], end_frame) - if i0 != i1: - # protect from spikes on border : spike_time<0 or spike_time>seg_size - # useful only when max_spikes_per_unit is not None - # waveform will not be extracted and a zeros will be left in the memmap file - while (in_seg_spikes[i0]["sample_index"] - nbefore) < 0 and (i0 != i1): - i0 = i0 + 1 - while (in_seg_spikes[i1 - 1]["sample_index"] + nafter) > seg_size and (i0 != i1): - i1 = i1 - 1 + # the border of segment are protected by nbefore on left an nafter on the right + i0 = np.searchsorted(in_seg_spikes["sample_index"], max(start_frame, nbefore)) + i1 = np.searchsorted(in_seg_spikes["sample_index"], min(end_frame, seg_size - nafter)) + # slice in absolut in spikes vector l0 = i0 + s0 @@ -420,6 +414,9 @@ def extract_waveforms_to_single_buffer( needed to recover waveforms unit by unit. Importantly in case of sparsity, the channels are not aligned across units. + Note: spikes near borders (nbefore/nafter) are not extracted and 0 are put the output buffer. + This ensures that spikes.shape[0] == all_waveforms.shape[0]. + Important note: for the "shared_memory" mode wf_array_info contains reference to the shared memmory buffer, this variable must be reference as long as arrays as used. And this variable is also returned. @@ -449,10 +446,13 @@ def extract_waveforms_to_single_buffer( sparsity_mask: None or array of bool If not None shape must be must be (len(unit_ids), len(channel_ids)) copy: bool - If True (default), the output shared memory object is copied to a numpy standard array. - If copy=False then arrays_info is also return. Please keep in mind that arrays_info - need to be referenced as long as waveforms_by_units will be used otherwise it will be very hard to debug. - Also when copy=False the SharedMemory will need to be unlink manually + If True (default), the output shared memory object is copied to a numpy standard array and no reference + to the internal shared memory object is kept. + If copy=False then the shared memory object is also returned. Please keep in mind that the shared memory object + need to be referenced as long as all_waveforms will be used otherwise it might produce segmentation + faults which are hard to debug. + Also when copy=False the SharedMemory will need to be unlink manually if proper cleanup of resources is desired. 
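The copy=False contract spelled out just above follows the standard multiprocessing.shared_memory lifecycle: every numpy view of the buffer has to be dropped before its handle is closed, and exactly one owner calls unlink() at the end. A minimal sketch of that lifecycle, independent of SpikeInterface (shape, dtype and variable names are placeholders):

    import numpy as np
    from multiprocessing.shared_memory import SharedMemory

    # creator side: allocate the block and view it as an array
    shape, dtype = (100, 60, 4), np.dtype("float32")
    shm = SharedMemory(create=True, size=int(np.prod(shape)) * dtype.itemsize)
    all_waveforms = np.ndarray(shape, dtype=dtype, buffer=shm.buf)
    all_waveforms[:] = 0.0

    # consumer side (e.g. a worker process): attach by name, not by passing the array
    attached = SharedMemory(name=shm.name)
    view = np.ndarray(shape, dtype=dtype, buffer=attached.buf)

    # cleanup: drop the views, close each handle, unlink exactly once (owner only)
    del view
    attached.close()
    del all_waveforms
    shm.close()
    shm.unlink()

Keeping a reference to the SharedMemory object alive for as long as the array is in use is exactly why wf_array_info is returned alongside the buffer when copy=False.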
+ {} Returns @@ -481,7 +481,8 @@ def extract_waveforms_to_single_buffer( if mode == "memmap": all_waveforms = np.lib.format.open_memmap(file_path, mode="w+", dtype=dtype, shape=shape) - wf_array_info = str(file_path) + # wf_array_info = str(file_path) + wf_array_info = dict(filename=str(file_path)) elif mode == "shared_memory": if num_spikes == 0 or num_chans == 0: all_waveforms = np.zeros(shape, dtype=dtype) @@ -490,7 +491,8 @@ def extract_waveforms_to_single_buffer( else: all_waveforms, shm = make_shared_array(shape, dtype) shm_name = shm.name - wf_array_info = (shm, shm_name, dtype.str, shape) + # wf_array_info = (shm, shm_name, dtype.str, shape) + wf_array_info = dict(shm=shm, shm_name=shm_name, dtype=dtype.str, shape=shape) else: raise ValueError("allocate_waveforms_buffers bad mode") @@ -503,7 +505,6 @@ def extract_waveforms_to_single_buffer( init_args = ( recording, - unit_ids, spikes, wf_array_info, nbefore, @@ -532,7 +533,7 @@ def extract_waveforms_to_single_buffer( def _init_worker_distribute_single_buffer( - recording, unit_ids, spikes, wf_array_info, nbefore, nafter, return_scaled, mode, sparsity_mask + recording, spikes, wf_array_info, nbefore, nafter, return_scaled, mode, sparsity_mask ): worker_ctx = {} worker_ctx["recording"] = recording @@ -545,13 +546,13 @@ def _init_worker_distribute_single_buffer( worker_ctx["mode"] = mode if mode == "memmap": - filename = wf_array_info + filename = wf_array_info["filename"] all_waveforms = np.load(str(filename), mmap_mode="r+") worker_ctx["all_waveforms"] = all_waveforms elif mode == "shared_memory": from multiprocessing.shared_memory import SharedMemory - shm, shm_name, dtype, shape = wf_array_info + shm_name, dtype, shape = wf_array_info["shm_name"], wf_array_info["dtype"], wf_array_info["shape"] shm = SharedMemory(shm_name) all_waveforms = np.ndarray(shape=shape, dtype=dtype, buffer=shm.buf) worker_ctx["shm"] = shm @@ -587,16 +588,10 @@ def _worker_distribute_single_buffer(segment_index, start_frame, end_frame, work # take only spikes in range [start_frame, end_frame] # this is a slice so no copy!! 
- i0 = np.searchsorted(in_seg_spikes["sample_index"], start_frame) - i1 = np.searchsorted(in_seg_spikes["sample_index"], end_frame) - if i0 != i1: - # protect from spikes on border : spike_time<0 or spike_time>seg_size - # useful only when max_spikes_per_unit is not None - # waveform will not be extracted and a zeros will be left in the memmap file - while (in_seg_spikes[i0]["sample_index"] - nbefore) < 0 and (i0 != i1): - i0 = i0 + 1 - while (in_seg_spikes[i1 - 1]["sample_index"] + nafter) > seg_size and (i0 != i1): - i1 = i1 - 1 + # the border of segment are protected by nbefore on left an nafter on the right + i0 = np.searchsorted(in_seg_spikes["sample_index"], max(start_frame, nbefore)) + i1 = np.searchsorted(in_seg_spikes["sample_index"], min(end_frame, seg_size - nafter)) + # slice in absolut in spikes vector l0 = i0 + s0 @@ -611,17 +606,17 @@ def _worker_distribute_single_buffer(segment_index, start_frame, end_frame, work start_frame=start, end_frame=end, segment_index=segment_index, return_scaled=return_scaled ) - for spike_ind in range(l0, l1): - sample_index = spikes[spike_ind]["sample_index"] - unit_index = spikes[spike_ind]["unit_index"] + for spike_index in range(l0, l1): + sample_index = spikes[spike_index]["sample_index"] + unit_index = spikes[spike_index]["unit_index"] wf = traces[sample_index - start - nbefore : sample_index - start + nafter, :] if sparsity_mask is None: - all_waveforms[spike_ind, :, :] = wf + all_waveforms[spike_index, :, :] = wf else: mask = sparsity_mask[unit_index, :] wf = wf[:, mask] - all_waveforms[spike_ind, :, : wf.shape[1]] = wf + all_waveforms[spike_index, :, : wf.shape[1]] = wf if worker_ctx["mode"] == "memmap": all_waveforms.flush() @@ -642,7 +637,7 @@ def split_waveforms_by_units(unit_ids, spikes, all_waveforms, sparsity_mask=None sparsity_mask : None or numpy array Optionally the boolean sparsity mask folder : None or str or Path - If a folde ri sgiven all + If a folder is given all waveforms by units are copied in a npy file using f"waveforms_{unit_id}.npy" naming. Returns ------- From e37b0515742b129984eb75da35c869a1de6b78d3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 12 Sep 2023 08:13:32 +0000 Subject: [PATCH 51/73] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/waveform_tools.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index 6e0d6f412b..39623da329 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -354,7 +354,6 @@ def _worker_distribute_buffers(segment_index, start_frame, end_frame, worker_ctx i0 = np.searchsorted(in_seg_spikes["sample_index"], max(start_frame, nbefore)) i1 = np.searchsorted(in_seg_spikes["sample_index"], min(end_frame, seg_size - nafter)) - # slice in absolut in spikes vector l0 = i0 + s0 l1 = i1 + s0 @@ -416,7 +415,7 @@ def extract_waveforms_to_single_buffer( Note: spikes near borders (nbefore/nafter) are not extracted and 0 are put the output buffer. This ensures that spikes.shape[0] == all_waveforms.shape[0]. - + Important note: for the "shared_memory" mode wf_array_info contains reference to the shared memmory buffer, this variable must be reference as long as arrays as used. And this variable is also returned. 
@@ -481,7 +480,7 @@ def extract_waveforms_to_single_buffer( if mode == "memmap": all_waveforms = np.lib.format.open_memmap(file_path, mode="w+", dtype=dtype, shape=shape) - # wf_array_info = str(file_path) + # wf_array_info = str(file_path) wf_array_info = dict(filename=str(file_path)) elif mode == "shared_memory": if num_spikes == 0 or num_chans == 0: @@ -592,7 +591,6 @@ def _worker_distribute_single_buffer(segment_index, start_frame, end_frame, work i0 = np.searchsorted(in_seg_spikes["sample_index"], max(start_frame, nbefore)) i1 = np.searchsorted(in_seg_spikes["sample_index"], min(end_frame, seg_size - nafter)) - # slice in absolut in spikes vector l0 = i0 + s0 l1 = i1 + s0 From 45f2b15b286e5b071cf92ec5f18257e3a641e332 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 12 Sep 2023 11:41:34 +0200 Subject: [PATCH 52/73] Feeback from Alessio --- src/spikeinterface/core/node_pipeline.py | 21 +++++++++---------- .../core/tests/test_node_pipeline.py | 9 ++++---- 2 files changed, 14 insertions(+), 16 deletions(-) diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index 610ae42398..64949357c4 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -139,21 +139,20 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): # this is not implemented yet this will be done in separted PR class SpikeRetriever(PeakSource): """ - This class is usefull to inject a sorting object in the node pipepline mechanisim. + This class is useful to inject a sorting object in the node pipepline mechanism. It allows to compute some post processing with the same machinery used for sorting components. - This is a first step to totaly refactor: + This is used by: * compute_spike_locations() * compute_amplitude_scalings() * compute_spike_amplitudes() * compute_principal_components() - - recording: - - sorting: - - channel_from_template: bool (default True) - If True then the channel_index is infered from template and extremum_channel_inds must be provided. + recording : BaseRecording + The recording object. + sorting: BaseSorting + The sorting object. + channel_from_template: bool, default: True + If True, then the channel_index is inferred from the template and `extremum_channel_inds` must be provided. If False every spikes compute its own channel index given a radius around the template max channel. extremum_channel_inds: dict of int The extremum channel index dict given from template. 
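The SpikeRetriever documented above plugs a sorting into the peak-based node pipeline by turning its spike vector into the same structured "peaks" array that peak detection would produce; the sorting_to_peaks helper further down in this diff does that conversion. A toy version of the idea, with a stand-in dtype that only mimics the real base_peak_dtype and assuming the extremum-channel dict is ordered like the sorting's unit ids:

    import numpy as np

    # stand-in for base_peak_dtype; the real field definitions live in spikeinterface.core
    peak_dtype = [("sample_index", "int64"), ("channel_index", "int64"),
                  ("amplitude", "float64"), ("segment_index", "int64")]

    def spikes_to_peaks(spikes, extremum_channel_inds):
        """Give every spike the extremum channel of its unit so it looks like a detected peak."""
        peaks = np.zeros(spikes.size, dtype=peak_dtype)
        peaks["sample_index"] = spikes["sample_index"]
        peaks["segment_index"] = spikes["segment_index"]
        # position in the dict is assumed to match the spike's unit_index
        channel_of_unit = np.array(list(extremum_channel_inds.values()))
        peaks["channel_index"] = channel_of_unit[spikes["unit_index"]]
        return peaks

    spikes = np.array([(10, 0, 0), (50, 1, 0)],
                      dtype=[("sample_index", "int64"), ("unit_index", "int64"), ("segment_index", "int64")])
    print(spikes_to_peaks(spikes, {"unit_a": 3, "unit_b": 7})["channel_index"])   # [3 7]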
@@ -174,7 +173,7 @@ def __init__( assert extremum_channel_inds is not None, "SpikeRetriever need the dict extremum_channel_inds" - self.peaks = sorting_to_peak(sorting, extremum_channel_inds) + self.peaks = sorting_to_peaks(sorting, extremum_channel_inds) if not channel_from_template: channel_distance = get_channel_distances(recording) @@ -223,7 +222,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): return (local_peaks,) -def sorting_to_peak(sorting, extremum_channel_inds): +def sorting_to_peaks(sorting, extremum_channel_inds): spikes = sorting.to_spike_vector() peaks = np.zeros(spikes.size, dtype=base_peak_dtype) peaks["sample_index"] = spikes["sample_index"] diff --git a/src/spikeinterface/core/tests/test_node_pipeline.py b/src/spikeinterface/core/tests/test_node_pipeline.py index d0d49b865c..bcb15b6455 100644 --- a/src/spikeinterface/core/tests/test_node_pipeline.py +++ b/src/spikeinterface/core/tests/test_node_pipeline.py @@ -15,7 +15,7 @@ SpikeRetriever, PipelineNode, ExtractDenseWaveforms, - sorting_to_peak, + sorting_to_peaks, ) @@ -72,15 +72,14 @@ def compute(self, traces, peaks, waveforms): def test_run_node_pipeline(): recording, sorting = generate_ground_truth_recording(num_channels=10, num_units=10, durations=[10.0]) - # job_kwargs = dict(chunk_duration="0.5s", n_jobs=2, progress_bar=False) - job_kwargs = dict(chunk_duration="0.5s", n_jobs=1, progress_bar=False) + job_kwargs = dict(chunk_duration="0.5s", n_jobs=2, progress_bar=False) spikes = sorting.to_spike_vector() # create peaks from spikes we = extract_waveforms(recording, sorting, mode="memory", **job_kwargs) extremum_channel_inds = get_template_extremum_channel(we, peak_sign="neg", outputs="index") - peaks = sorting_to_peak(sorting, extremum_channel_inds) + peaks = sorting_to_peaks(sorting, extremum_channel_inds) peak_retriever = PeakRetriever(recording, peaks) # channel index is from template @@ -97,7 +96,7 @@ def test_run_node_pipeline(): peak_sign="neg", ) - # test with 2 diffrents first node + # test with 3 differents first nodes for loop, peak_source in enumerate((peak_retriever, spike_retriever_T, spike_retriever_S)): # one step only : squeeze output nodes = [ From 4ea63ef901fb9d05308179b3d465e85dc777ab16 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 12 Sep 2023 16:54:37 +0200 Subject: [PATCH 53/73] Update src/spikeinterface/extractors/neoextractors/neuroexplorer.py Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/extractors/neoextractors/neuroexplorer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/extractors/neoextractors/neuroexplorer.py b/src/spikeinterface/extractors/neoextractors/neuroexplorer.py index b430e45232..2c8603cb9c 100644 --- a/src/spikeinterface/extractors/neoextractors/neuroexplorer.py +++ b/src/spikeinterface/extractors/neoextractors/neuroexplorer.py @@ -15,7 +15,7 @@ class NeuroExplorerRecordingExtractor(NeoBaseRecordingExtractor): This is because the NeuroExplorerRawIO class does not support multi-channel recordings as in the NeuroExplorer format they might have different sampling rates. - Consider exctracting all the channels and then concatenating them with the aggregate_channels function. + Consider extracting all the channels and then concatenating them with the aggregate_channels function. 
>>> from spikeinterface.extractors.neoextractors.neuroexplorer import NeuroExplorerRecordingExtractor >>> from spikeinterface.core import aggregate_channels From 84f7e21d1a48385a0f4c86ee886eea793325bb09 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 12 Sep 2023 17:51:17 +0200 Subject: [PATCH 54/73] add to the API --- doc/api.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/doc/api.rst b/doc/api.rst index 2e9fc1567a..fdef00c928 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -101,6 +101,8 @@ NEO-based .. autofunction:: read_spikegadgets .. autofunction:: read_spikeglx .. autofunction:: read_tdt + .. autofunction:: read_neuroexplorer + Non-NEO-based ~~~~~~~~~~~~~ From 966d56a9fa150472e43f888ad4e6b62f89a77ef6 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 13 Sep 2023 08:18:51 +0200 Subject: [PATCH 55/73] doc --- src/spikeinterface/core/waveform_tools.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index 39623da329..da8e3d64b6 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -417,9 +417,11 @@ def extract_waveforms_to_single_buffer( This ensures that spikes.shape[0] == all_waveforms.shape[0]. Important note: for the "shared_memory" mode wf_array_info contains reference to - the shared memmory buffer, this variable must be reference as long as arrays as used. - And this variable is also returned. - To avoid this a copy to non shared memmory can be perform at the end. + the shared memmory buffer, this variable must be referenced as long as arrays is used. + This variable must also unlink() when the array is de-referenced. + To avoid this complicated behavior, by default (copy=True) the shared memmory buffer is copied into a standard + numpy array. 
+ Parameters ---------- From e73cf7e107026d80b176e9fb420b31cea964b730 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 13 Sep 2023 10:27:39 +0200 Subject: [PATCH 56/73] Simplify plexon2 tests (only run when dependencies are installed) --- .github/workflows/full-test-with-codecov.yml | 2 +- .github/workflows/full-test.yml | 5 +---- src/spikeinterface/extractors/tests/test_neoextractors.py | 5 +---- 3 files changed, 3 insertions(+), 9 deletions(-) diff --git a/.github/workflows/full-test-with-codecov.yml b/.github/workflows/full-test-with-codecov.yml index d0bf109a00..a5561c2ffc 100644 --- a/.github/workflows/full-test-with-codecov.yml +++ b/.github/workflows/full-test-with-codecov.yml @@ -54,7 +54,7 @@ jobs: - name: run tests run: | source ${{ github.workspace }}/test_env/bin/activate - pytest -m "not sorters_external and not plexon2" --cov=./ --cov-report xml:./coverage.xml -vv -ra --durations=0 | tee report_full.txt; test ${PIPESTATUS[0]} -eq 0 || exit 1 + pytest -m "not sorters_external" --cov=./ --cov-report xml:./coverage.xml -vv -ra --durations=0 | tee report_full.txt; test ${PIPESTATUS[0]} -eq 0 || exit 1 echo "# Timing profile of full tests" >> $GITHUB_STEP_SUMMARY python ./.github/build_job_summary.py report_full.txt >> $GITHUB_STEP_SUMMARY cat $GITHUB_STEP_SUMMARY diff --git a/.github/workflows/full-test.yml b/.github/workflows/full-test.yml index a343500c08..8f88e84039 100644 --- a/.github/workflows/full-test.yml +++ b/.github/workflows/full-test.yml @@ -133,7 +133,7 @@ jobs: run: ./.github/run_tests.sh core - name: Test extractors if: ${{ steps.modules-changed.outputs.EXTRACTORS_CHANGED == 'true' || steps.modules-changed.outputs.CORE_CHANGED == 'true' }} - run: ./.github/run_tests.sh "extractors and not streaming_extractors and not plexon2" + run: ./.github/run_tests.sh "extractors and not streaming_extractors" - name: Test preprocessing if: ${{ steps.modules-changed.outputs.PREPROCESSING_CHANGED == 'true' || steps.modules-changed.outputs.CORE_CHANGED == 'true' }} run: ./.github/run_tests.sh preprocessing @@ -164,6 +164,3 @@ jobs: - name: Test internal sorters if: ${{ steps.modules-changed.outputs.SORTERS_INTERNAL_CHANGED == 'true' || steps.modules-changed.outputs.SORTINGCOMPONENTS_CHANGED || steps.modules-changed.outputs.CORE_CHANGED == 'true' }} run: ./.github/run_tests.sh sorters_internal - - name: Test plexon2 - if: ${{ steps.modules-changed.outputs.PLEXON2_CHANGED == 'true' }} - run: ./.github/run_tests.sh plexon2 diff --git a/src/spikeinterface/extractors/tests/test_neoextractors.py b/src/spikeinterface/extractors/tests/test_neoextractors.py index 5fe42b0c4e..ce2703d382 100644 --- a/src/spikeinterface/extractors/tests/test_neoextractors.py +++ b/src/spikeinterface/extractors/tests/test_neoextractors.py @@ -323,9 +323,8 @@ def test_pickling(self): pass -# We mark plexon2 tests as they require additional dependencies (wine) +# We run plexon2 tests only if we have dependencies (wine) @pytest.mark.skipif(not has_plexon2_dependencies(), reason="Required dependencies not installed") -@pytest.mark.plexon2 class Plexon2RecordingTest(RecordingCommonTestSuite, unittest.TestCase): ExtractorClass = Plexon2RecordingExtractor downloads = ["plexon"] @@ -335,7 +334,6 @@ class Plexon2RecordingTest(RecordingCommonTestSuite, unittest.TestCase): @pytest.mark.skipif(not has_plexon2_dependencies(), reason="Required dependencies not installed") -@pytest.mark.plexon2 class Plexon2EventTest(EventCommonTestSuite, unittest.TestCase): ExtractorClass = Plexon2EventExtractor downloads = 
["plexon"] @@ -345,7 +343,6 @@ class Plexon2EventTest(EventCommonTestSuite, unittest.TestCase): @pytest.mark.skipif(not has_plexon2_dependencies(), reason="Required dependencies not installed") -@pytest.mark.plexon2 class Plexon2SortingTest(SortingCommonTestSuite, unittest.TestCase): ExtractorClass = Plexon2SortingExtractor downloads = ["plexon"] From e3b96d3f4aa8cfa63b4703726950e81d86aa43df Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 13 Sep 2023 10:52:19 +0200 Subject: [PATCH 57/73] Relax check_borders (and fix typo) in InjectTemplatesRecording --- src/spikeinterface/core/generate.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index bbf77682ee..d78a2e4e57 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1,5 +1,5 @@ import math - +import warnings import numpy as np from typing import Union, Optional, List, Literal @@ -1037,13 +1037,14 @@ def __init__( parent_recording: Union[BaseRecording, None] = None, num_samples: Optional[List[int]] = None, upsample_vector: Union[List[int], None] = None, - check_borbers: bool = True, + check_borders: bool = False, ) -> None: templates = np.asarray(templates) - if check_borbers: + # TODO: this should be external to this class. It is not the responsability of this class to check the templates + if check_borders: self._check_templates(templates) - # lets test this only once so force check_borbers=false for kwargs - check_borbers = False + # lets test this only once so force check_borders=False for kwargs + check_borders = False self.templates = templates channel_ids = parent_recording.channel_ids if parent_recording is not None else list(range(templates.shape[2])) @@ -1131,7 +1132,7 @@ def __init__( "nbefore": nbefore, "amplitude_factor": amplitude_factor, "upsample_vector": upsample_vector, - "check_borbers": check_borbers, + "check_borders": check_borders, } if parent_recording is None: self._kwargs["num_samples"] = num_samples @@ -1144,8 +1145,8 @@ def _check_templates(templates: np.ndarray): threshold = 0.01 * max_value if max(np.max(np.abs(templates[:, 0])), np.max(np.abs(templates[:, -1]))) > threshold: - raise Exception( - "Warning!\nYour templates do not go to 0 on the edges in InjectTemplatesRecording.__init__\nPlease make your window bigger." + warnings.warn( + "Warning! Your templates do not go to 0 on the edges in InjectTemplatesRecording. Please make your window bigger." ) From 40d304b26572c26caecefed5084c190d9a76c3ec Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Wed, 13 Sep 2023 12:14:42 +0200 Subject: [PATCH 58/73] aphabetical order in API --- doc/api.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/api.rst b/doc/api.rst index fdef00c928..7a72ead33f 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -91,6 +91,7 @@ NEO-based .. autofunction:: read_mcsraw .. autofunction:: read_neuralynx .. autofunction:: read_neuralynx_sorting + .. autofunction:: read_neuroexplorer .. autofunction:: read_neuroscope .. autofunction:: read_nix .. autofunction:: read_openephys @@ -101,7 +102,6 @@ NEO-based .. autofunction:: read_spikegadgets .. autofunction:: read_spikeglx .. autofunction:: read_tdt - .. 
autofunction:: read_neuroexplorer Non-NEO-based From cfdc0867c2fcc5bf02b55a359abd71b91656efd0 Mon Sep 17 00:00:00 2001 From: Chetan Kandpal Date: Wed, 13 Sep 2023 16:38:23 +0530 Subject: [PATCH 59/73] Doc Changes for recently added Plexon 2 support --- doc/api.rst | 2 ++ doc/modules/extractors.rst | 4 +++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/doc/api.rst b/doc/api.rst index 7a72ead33f..d8990888c4 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -98,6 +98,8 @@ NEO-based .. autofunction:: read_openephys_event .. autofunction:: read_plexon .. autofunction:: read_plexon_sorting + .. autofunction:: read_plexon2 + .. autofunction:: read_plexon2_sorting .. autofunction:: read_spike2 .. autofunction:: read_spikegadgets .. autofunction:: read_spikeglx diff --git a/doc/modules/extractors.rst b/doc/modules/extractors.rst index a6752e2c9d..0a0ad90ffa 100644 --- a/doc/modules/extractors.rst +++ b/doc/modules/extractors.rst @@ -135,7 +135,8 @@ For raw recording formats, we currently support: * **Neuralynx** :py:func:`~spikeinterface.extractors.read_neuralynx()` * **Open Ephys Legacy** :py:func:`~spikeinterface.extractors.read_openephys()` * **Open Ephys Binary** :py:func:`~spikeinterface.extractors.read_openephys()` -* **Plexon** :py:func:`~spikeinterface.corextractorse.read_plexon()` +* **Plexon** :py:func:`~spikeinterface.extractors.read_plexon()` +* **Plexon 2** :py:func:`~spikeinterface.extractors.read_plexon2()` * **Shybrid** :py:func:`~spikeinterface.extractors.read_shybrid_recording()` * **SpikeGLX** :py:func:`~spikeinterface.extractors.read_spikeglx()` * **SpikeGLX IBL compressed** :py:func:`~spikeinterface.extractors.read_cbin_ibl()` @@ -165,6 +166,7 @@ For sorted data formats, we currently support: * **Neuralynx spikes** :py:func:`~spikeinterface.extractors.read_neuralynx_sorting()` * **NPZ (created by SpikeInterface)** :py:func:`~spikeinterface.core.read_npz_sorting()` * **Plexon spikes** :py:func:`~spikeinterface.extractors.read_plexon_sorting()` +* **Plexon 2 spikes** :py:func:`~spikeinterface.extractors.read_plexon2_sorting()` * **Shybrid** :py:func:`~spikeinterface.extractors.read_shybrid_sorting()` * **Spyking Circus** :py:func:`~spikeinterface.extractors.read_spykingcircus()` * **Trideclous** :py:func:`~spikeinterface.extractors.read_tridesclous()` From d812308d219d747d4a34ca19650636a53f533879 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 13 Sep 2023 11:11:19 +0000 Subject: [PATCH 60/73] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- doc/api.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/api.rst b/doc/api.rst index d8990888c4..43f79386e6 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -99,7 +99,7 @@ NEO-based .. autofunction:: read_plexon .. autofunction:: read_plexon_sorting .. autofunction:: read_plexon2 - .. autofunction:: read_plexon2_sorting + .. autofunction:: read_plexon2_sorting .. autofunction:: read_spike2 .. autofunction:: read_spikegadgets .. 
autofunction:: read_spikeglx From e2a0472d2c1c53e5d5fd58775d7e8677cf8912d7 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 13 Sep 2023 13:35:07 +0200 Subject: [PATCH 61/73] oups --- src/spikeinterface/core/node_pipeline.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index 64949357c4..14964ac7c3 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -140,7 +140,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): class SpikeRetriever(PeakSource): """ This class is useful to inject a sorting object in the node pipepline mechanism. - It allows to compute some post processing with the same machinery used for sorting components. + It allows to compute some post-processing steps with the same machinery used for sorting components. This is used by: * compute_spike_locations() * compute_amplitude_scalings() @@ -153,7 +153,7 @@ class SpikeRetriever(PeakSource): The sorting object. channel_from_template: bool, default: True If True, then the channel_index is inferred from the template and `extremum_channel_inds` must be provided. - If False every spikes compute its own channel index given a radius around the template max channel. + If False, the max channel is computed for each spike given a radius around the template max channel. extremum_channel_inds: dict of int The extremum channel index dict given from template. radius_um: float (default 50.) From ad0f05e555d1e910ec80c8759d963ca27d71bf58 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 13 Sep 2023 16:39:59 +0200 Subject: [PATCH 62/73] Update src/spikeinterface/core/node_pipeline.py --- src/spikeinterface/core/node_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index 14964ac7c3..b11f40a441 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -171,7 +171,7 @@ def __init__( self.channel_from_template = channel_from_template - assert extremum_channel_inds is not None, "SpikeRetriever need the dict extremum_channel_inds" + assert extremum_channel_inds is not None, "SpikeRetriever needs the extremum_channel_inds dictionary" self.peaks = sorting_to_peaks(sorting, extremum_channel_inds) From badfadc523edece56f5fa0d21abb4ddd76b11998 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 13 Sep 2023 16:59:26 +0200 Subject: [PATCH 63/73] Update .github/workflows/installation-tips-test.yml --- .github/workflows/installation-tips-test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/installation-tips-test.yml b/.github/workflows/installation-tips-test.yml index 64a4e45270..0e522e6baa 100644 --- a/.github/workflows/installation-tips-test.yml +++ b/.github/workflows/installation-tips-test.yml @@ -3,7 +3,7 @@ name: Creates Conda Install for Installation Tips on: workflow_dispatch: schedule: - - cron: "0 12 * * *" # Daily at noon UTC + - cron: "0 12 * * 0" # Weekly at noon UTC on Sundays jobs: installation-tips-testing: From d3987b2459f9e55e7466c8c14b27b72754346e51 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Wed, 13 Sep 2023 11:32:04 -0400 Subject: [PATCH 64/73] document fixes --- doc/conf.py | 2 + doc/development/development.rst | 2 +- doc/how_to/analyse_neuropixels.rst | 64 ++++++++++---------- 
doc/how_to/get_started.rst | 70 +++++++++++----------- doc/how_to/handle_drift.rst | 24 ++++---- doc/modules/qualitymetrics/synchrony.rst | 2 +- doc/modules/sorters.rst | 2 +- src/spikeinterface/preprocessing/motion.py | 10 ++-- 8 files changed, 89 insertions(+), 87 deletions(-) diff --git a/doc/conf.py b/doc/conf.py index 847de9ff42..15cb65d46a 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -67,6 +67,8 @@ 'numpydoc', "sphinx.ext.intersphinx", "sphinx.ext.extlinks", + "IPython.sphinxext.ipython_directive", + "IPython.sphinxext.ipython_console_highlighting" ] numpydoc_show_class_members = False diff --git a/doc/development/development.rst b/doc/development/development.rst index cd613a27e6..f1371639c3 100644 --- a/doc/development/development.rst +++ b/doc/development/development.rst @@ -1,5 +1,5 @@ Development -========== +=========== How to contribute ----------------- diff --git a/doc/how_to/analyse_neuropixels.rst b/doc/how_to/analyse_neuropixels.rst index c921b13719..37646c2146 100644 --- a/doc/how_to/analyse_neuropixels.rst +++ b/doc/how_to/analyse_neuropixels.rst @@ -4,11 +4,11 @@ Analyse Neuropixels datasets This example shows how to perform Neuropixels-specific analysis, including custom pre- and post-processing. -.. code:: ipython3 +.. code:: ipython %matplotlib inline -.. code:: ipython3 +.. code:: ipython import spikeinterface.full as si @@ -16,7 +16,7 @@ including custom pre- and post-processing. import matplotlib.pyplot as plt from pathlib import Path -.. code:: ipython3 +.. code:: ipython base_folder = Path('/mnt/data/sam/DataSpikeSorting/neuropixel_example/') @@ -29,7 +29,7 @@ Read the data The ``SpikeGLX`` folder can contain several “streams” (AP, LF and NIDQ). We need to specify which one to read: -.. code:: ipython3 +.. code:: ipython stream_names, stream_ids = si.get_neo_streams('spikeglx', spikeglx_folder) stream_names @@ -43,7 +43,7 @@ We need to specify which one to read: -.. code:: ipython3 +.. code:: ipython # we do not load the sync channel, so the probe is automatically loaded raw_rec = si.read_spikeglx(spikeglx_folder, stream_name='imec0.ap', load_sync_channel=False) @@ -58,7 +58,7 @@ We need to specify which one to read: -.. code:: ipython3 +.. code:: ipython # we automaticaly have the probe loaded! raw_rec.get_probe().to_dataframe() @@ -201,7 +201,7 @@ We need to specify which one to read: -.. code:: ipython3 +.. code:: ipython fig, ax = plt.subplots(figsize=(15, 10)) si.plot_probe_map(raw_rec, ax=ax, with_channel_ids=True) @@ -229,7 +229,7 @@ Let’s do something similar to the IBL destriping chain (See - instead of interpolating bad channels, we remove then. - instead of highpass_spatial_filter() we use common_reference() -.. code:: ipython3 +.. code:: ipython rec1 = si.highpass_filter(raw_rec, freq_min=400.) bad_channel_ids, channel_labels = si.detect_bad_channels(rec1) @@ -271,7 +271,7 @@ preprocessing chain wihtout to save the entire file to disk. Everything is lazy, so you can change the previsous cell (parameters, step order, …) and visualize it immediatly. -.. code:: ipython3 +.. code:: ipython # here we use static plot using matplotlib backend fig, axs = plt.subplots(ncols=3, figsize=(20, 10)) @@ -287,7 +287,7 @@ is lazy, so you can change the previsous cell (parameters, step order, .. image:: analyse_neuropixels_files/analyse_neuropixels_13_0.png -.. code:: ipython3 +.. code:: ipython # plot some channels fig, ax = plt.subplots(figsize=(20, 10)) @@ -326,7 +326,7 @@ Depending on the complexity of the preprocessing chain, this operation can take a while. 
However, we can make use of the powerful parallelization mechanism of SpikeInterface. -.. code:: ipython3 +.. code:: ipython job_kwargs = dict(n_jobs=40, chunk_duration='1s', progress_bar=True) @@ -344,7 +344,7 @@ parallelization mechanism of SpikeInterface. write_binary_recording: 0%| | 0/1139 [00:00 0.9) -.. code:: ipython3 +.. code:: ipython keep_units = metrics.query(our_query) keep_unit_ids = keep_units.index.values @@ -1071,11 +1071,11 @@ In order to export the final results we need to make a copy of the the waveforms, but only for the selected units (so we can avoid to compute them again). -.. code:: ipython3 +.. code:: ipython we_clean = we.select_units(keep_unit_ids, new_folder=base_folder / 'waveforms_clean') -.. code:: ipython3 +.. code:: ipython we_clean @@ -1091,12 +1091,12 @@ them again). Then we export figures to a report folder -.. code:: ipython3 +.. code:: ipython # export spike sorting report to a folder si.export_report(we_clean, base_folder / 'report', format='png') -.. code:: ipython3 +.. code:: ipython we_clean = si.load_waveforms(base_folder / 'waveforms_clean') we_clean diff --git a/doc/how_to/get_started.rst b/doc/how_to/get_started.rst index a235eb4272..a923393916 100644 --- a/doc/how_to/get_started.rst +++ b/doc/how_to/get_started.rst @@ -11,7 +11,7 @@ dataset, and we will then perform some pre-processing, run a spike sorting algorithm, post-process the spike sorting output, perform curation (manual and automatic), and compare spike sorting results. -.. code:: ipython3 +.. code:: ipython import matplotlib.pyplot as plt from pprint import pprint @@ -19,7 +19,7 @@ curation (manual and automatic), and compare spike sorting results. The spikeinterface module by itself import only the spikeinterface.core submodule which is not useful for end user -.. code:: ipython3 +.. code:: ipython import spikeinterface @@ -35,7 +35,7 @@ We need to import one by one different submodules separately - ``comparison`` : comparison of spike sorting output - ``widgets`` : visualization -.. code:: ipython3 +.. code:: ipython import spikeinterface as si # import core only import spikeinterface.extractors as se @@ -56,14 +56,14 @@ This is useful for notebooks, but it is a heavier import because internally many more dependencies are imported (scipy/sklearn/networkx/matplotlib/h5py…) -.. code:: ipython3 +.. code:: ipython import spikeinterface.full as si Before getting started, we can set some global arguments for parallel processing. For this example, let’s use 4 jobs and time chunks of 1s: -.. code:: ipython3 +.. code:: ipython global_job_kwargs = dict(n_jobs=4, chunk_duration="1s") si.set_global_job_kwargs(**global_job_kwargs) @@ -75,7 +75,7 @@ Then we can open it. Note that `MEArec `__ simulated file contains both “recording” and a “sorting” object. -.. code:: ipython3 +.. code:: ipython local_path = si.download_dataset(remote_path='mearec/mearec_test_10s.h5') recording, sorting_true = se.read_mearec(local_path) @@ -102,7 +102,7 @@ ground-truth information of the spiking activity of each unit. Let’s use the ``spikeinterface.widgets`` module to visualize the traces and the raster plots. -.. code:: ipython3 +.. code:: ipython w_ts = sw.plot_traces(recording, time_range=(0, 5)) w_rs = sw.plot_rasters(sorting_true, time_range=(0, 5)) @@ -118,7 +118,7 @@ and the raster plots. This is how you retrieve info from a ``BaseRecording``\ … -.. code:: ipython3 +.. 
code:: ipython channel_ids = recording.get_channel_ids() fs = recording.get_sampling_frequency() @@ -143,7 +143,7 @@ This is how you retrieve info from a ``BaseRecording``\ … …and a ``BaseSorting`` -.. code:: ipython3 +.. code:: ipython num_seg = recording.get_num_segments() unit_ids = sorting_true.get_unit_ids() @@ -173,7 +173,7 @@ any probe in the probeinterface collections can be downloaded and set to a ``Recording`` object. In this case, the MEArec dataset already handles a ``Probe`` and we don’t need to set it *manually*. -.. code:: ipython3 +.. code:: ipython probe = recording.get_probe() print(probe) @@ -200,7 +200,7 @@ All these preprocessing steps are “lazy”. The computation is done on demand when we call ``recording.get_traces(...)`` or when we save the object to disk. -.. code:: ipython3 +.. code:: ipython recording_cmr = recording recording_f = si.bandpass_filter(recording, freq_min=300, freq_max=6000) @@ -224,7 +224,7 @@ Now you are ready to spike sort using the ``spikeinterface.sorters`` module! Let’s first check which sorters are implemented and which are installed -.. code:: ipython3 +.. code:: ipython print('Available sorters', ss.available_sorters()) print('Installed sorters', ss.installed_sorters()) @@ -241,7 +241,7 @@ machine. We can see we have HerdingSpikes and Tridesclous installed. Spike sorters come with a set of parameters that users can change. The available parameters are dictionaries and can be accessed with: -.. code:: ipython3 +.. code:: ipython print("Tridesclous params:") pprint(ss.get_default_sorter_params('tridesclous')) @@ -279,7 +279,7 @@ available parameters are dictionaries and can be accessed with: Let’s run ``tridesclous`` and change one of the parameter, say, the ``detect_threshold``: -.. code:: ipython3 +.. code:: ipython sorting_TDC = ss.run_sorter(sorter_name="tridesclous", recording=recording_preprocessed, detect_threshold=4) print(sorting_TDC) @@ -292,7 +292,7 @@ Let’s run ``tridesclous`` and change one of the parameter, say, the Alternatively we can pass full dictionary containing the parameters: -.. code:: ipython3 +.. code:: ipython other_params = ss.get_default_sorter_params('tridesclous') other_params['detect_threshold'] = 6 @@ -310,7 +310,7 @@ Alternatively we can pass full dictionary containing the parameters: Let’s run ``spykingcircus2`` as well, with default parameters: -.. code:: ipython3 +.. code:: ipython sorting_SC2 = ss.run_sorter(sorter_name="spykingcircus2", recording=recording_preprocessed) print(sorting_SC2) @@ -341,7 +341,7 @@ If a sorter is not installed locally, we can also avoid to install it and run it anyways, using a container (Docker or Singularity). For example, let’s run ``Kilosort2`` using Docker: -.. code:: ipython3 +.. code:: ipython sorting_KS2 = ss.run_sorter(sorter_name="kilosort2", recording=recording_preprocessed, docker_image=True, verbose=True) @@ -370,7 +370,7 @@ extracts, their waveforms, and stores them to disk. These waveforms are helpful to compute the average waveform, or “template”, for each unit and then to compute, for example, quality metrics. -.. code:: ipython3 +.. code:: ipython we_TDC = si.extract_waveforms(recording_preprocessed, sorting_TDC, 'waveforms_folder', overwrite=True) print(we_TDC) @@ -399,7 +399,7 @@ compute spike amplitudes, PCA projections, unit locations, and more. Let’s compute some postprocessing information that will be needed later for computing quality metrics, exporting, and visualization: -.. code:: ipython3 +.. 
code:: ipython amplitudes = spost.compute_spike_amplitudes(we_TDC) unit_locations = spost.compute_unit_locations(we_TDC) @@ -411,7 +411,7 @@ for computing quality metrics, exporting, and visualization: All of this postprocessing functions are saved in the waveforms folder as extensions: -.. code:: ipython3 +.. code:: ipython print(we_TDC.get_available_extension_names()) @@ -424,7 +424,7 @@ as extensions: Importantly, waveform extractors (and all extensions) can be reloaded at later times: -.. code:: ipython3 +.. code:: ipython we_loaded = si.load_waveforms('waveforms_folder') print(we_loaded.get_available_extension_names()) @@ -439,7 +439,7 @@ Once we have computed all these postprocessing information, we can compute quality metrics (different quality metrics require different extensions - e.g., drift metrics resuire ``spike_locations``): -.. code:: ipython3 +.. code:: ipython qm_params = sqm.get_default_qm_params() pprint(qm_params) @@ -485,14 +485,14 @@ extensions - e.g., drift metrics resuire ``spike_locations``): Since the recording is very short, let’s change some parameters to accomodate the duration: -.. code:: ipython3 +.. code:: ipython qm_params["presence_ratio"]["bin_duration_s"] = 1 qm_params["amplitude_cutoff"]["num_histogram_bins"] = 5 qm_params["drift"]["interval_s"] = 2 qm_params["drift"]["min_spikes_per_interval"] = 2 -.. code:: ipython3 +.. code:: ipython qm = sqm.compute_quality_metrics(we_TDC, qm_params=qm_params) display(qm) @@ -522,7 +522,7 @@ We can export a sorting summary and quality metrics plot using the ``sortingview`` backend. This will generate shareble links for web-based visualization. -.. code:: ipython3 +.. code:: ipython w1 = sw.plot_quality_metrics(we_TDC, display=False, backend="sortingview") @@ -530,7 +530,7 @@ visualization. https://figurl.org/f?v=gs://figurl/spikesortingview-10&d=sha1://901a11ba31ae9ab512a99bdf36a3874173249d87&label=SpikeInterface%20-%20Quality%20Metrics -.. code:: ipython3 +.. code:: ipython w2 = sw.plot_sorting_summary(we_TDC, display=False, curation=True, backend="sortingview") @@ -543,7 +543,7 @@ curation. In the example above, we manually merged two units (0, 4) and added accept labels (2, 6, 7). After applying our curation, we can click on the “Save as snapshot (sha://)” and copy the URI: -.. code:: ipython3 +.. code:: ipython uri = "sha1://68cb54a9aaed2303fb82dedbc302c853e818f1b6" @@ -562,7 +562,7 @@ Alternatively, we can export the data locally to Phy. `Phy `_ is a GUI for manual curation of the spike sorting output. To export to phy you can run: -.. code:: ipython3 +.. code:: ipython sexp.export_to_phy(we_TDC, 'phy_folder_for_TDC', verbose=True) @@ -581,7 +581,7 @@ After curating with Phy, the curated sorting can be reloaded to SpikeInterface. In this case, we exclude the units that have been labeled as “noise”: -.. code:: ipython3 +.. code:: ipython sorting_curated_phy = se.read_phy('phy_folder_for_TDC', exclude_cluster_groups=["noise"]) @@ -589,7 +589,7 @@ Quality metrics can be also used to automatically curate the spike sorting output. For example, you can select sorted units with a SNR above a certain threshold: -.. code:: ipython3 +.. code:: ipython keep_mask = (qm['snr'] > 10) & (qm['isi_violations_ratio'] < 0.01) print("Mask:", keep_mask.values) @@ -615,7 +615,7 @@ outputs. We can either: 3. compare the output of multiple sorters (Tridesclous, SpykingCircus2, Kilosort2) -.. code:: ipython3 +.. 
code:: ipython comp_gt = sc.compare_sorter_to_ground_truth(gt_sorting=sorting_true, tested_sorting=sorting_TDC) comp_pair = sc.compare_two_sorters(sorting1=sorting_TDC, sorting2=sorting_SC2) @@ -625,7 +625,7 @@ outputs. We can either: When comparing with a ground-truth sorting (1,), you can get the sorting performance and plot a confusion matrix -.. code:: ipython3 +.. code:: ipython print(comp_gt.get_performance()) w_conf = sw.plot_confusion_matrix(comp_gt) @@ -659,7 +659,7 @@ performance and plot a confusion matrix When comparing two sorters (2.), we can see the matching of units between sorters. Units which are not matched has -1 as unit id: -.. code:: ipython3 +.. code:: ipython comp_pair.hungarian_match_12 @@ -683,7 +683,7 @@ between sorters. Units which are not matched has -1 as unit id: or the reverse: -.. code:: ipython3 +.. code:: ipython comp_pair.hungarian_match_21 @@ -709,7 +709,7 @@ When comparing multiple sorters (3.), you can extract a ``BaseSorting`` object with units in agreement between sorters. You can also plot a graph showing how the units are matched between the sorters. -.. code:: ipython3 +.. code:: ipython sorting_agreement = comp_multi.get_agreement_sorting(minimum_agreement_count=2) diff --git a/doc/how_to/handle_drift.rst b/doc/how_to/handle_drift.rst index 7ff98a666b..5c4476187b 100644 --- a/doc/how_to/handle_drift.rst +++ b/doc/how_to/handle_drift.rst @@ -1,4 +1,4 @@ -.. code:: ipython3 +.. code:: ipython %matplotlib inline %load_ext autoreload @@ -42,7 +42,7 @@ Neuropixels 1 and a Neuropixels 2 probe. Here we will use *dataset1* with *neuropixel1*. This dataset is the *“hello world”* for drift correction in the spike sorting community! -.. code:: ipython3 +.. code:: ipython from pathlib import Path import matplotlib.pyplot as plt @@ -52,12 +52,12 @@ Here we will use *dataset1* with *neuropixel1*. This dataset is the import spikeinterface.full as si -.. code:: ipython3 +.. code:: ipython base_folder = Path('/mnt/data/sam/DataSpikeSorting/imposed_motion_nick') dataset_folder = base_folder / 'dataset1/NP1' -.. code:: ipython3 +.. code:: ipython # read the file raw_rec = si.read_spikeglx(dataset_folder) @@ -77,7 +77,7 @@ We preprocess the recording with bandpass filter and a common median reference. Note, that it is better to not whiten the recording before motion estimation to get a better estimate of peak locations! -.. code:: ipython3 +.. code:: ipython def preprocess_chain(rec): rec = si.bandpass_filter(rec, freq_min=300., freq_max=6000.) @@ -85,7 +85,7 @@ motion estimation to get a better estimate of peak locations! return rec rec = preprocess_chain(raw_rec) -.. code:: ipython3 +.. code:: ipython job_kwargs = dict(n_jobs=40, chunk_duration='1s', progress_bar=True) @@ -101,7 +101,7 @@ parameters for each step. Here we also save the motion correction results into a folder to be able to load them later. -.. code:: ipython3 +.. code:: ipython # internally, we can explore a preset like this # every parameter can be overwritten at runtime @@ -143,13 +143,13 @@ to load them later. -.. code:: ipython3 +.. code:: ipython # lets try theses 3 presets some_presets = ('rigid_fast', 'kilosort_like', 'nonrigid_accurate') # some_presets = ('nonrigid_accurate', ) -.. code:: ipython3 +.. code:: ipython # compute motion with 3 presets for preset in some_presets: @@ -195,7 +195,7 @@ A few comments on the figures: (2000-3000um) experience some drift, but the lower part (0-1000um) is relatively stable. The method defined by this preset is able to capture this. -.. 
code:: ipython3 +.. code:: ipython for preset in some_presets: # load @@ -243,7 +243,7 @@ locations (:py:func:`correct_motion_on_peaks()`) Case 1 is used before running a spike sorter and the case 2 is used here to display the results. -.. code:: ipython3 +.. code:: ipython from spikeinterface.sortingcomponents.motion_interpolation import correct_motion_on_peaks @@ -303,7 +303,7 @@ run times Presets and related methods have differents accuracies but also computation speeds. It is good to have this in mind! -.. code:: ipython3 +.. code:: ipython run_times = [] for preset in some_presets: diff --git a/doc/modules/qualitymetrics/synchrony.rst b/doc/modules/qualitymetrics/synchrony.rst index b41e194466..ba40885421 100644 --- a/doc/modules/qualitymetrics/synchrony.rst +++ b/doc/modules/qualitymetrics/synchrony.rst @@ -39,7 +39,7 @@ The SpikeInterface implementation is a partial port of the low-level complexity References ---------- -.. automodule:: spikeinterface.toolkit.qualitymetrics.misc_metrics +.. automodule:: spikeinterface.qualitymetrics.misc_metrics .. autofunction:: compute_synchrony_metrics diff --git a/doc/modules/sorters.rst b/doc/modules/sorters.rst index 26f2365202..34ab3d1151 100644 --- a/doc/modules/sorters.rst +++ b/doc/modules/sorters.rst @@ -130,7 +130,7 @@ Parameters from all sorters can be retrieved with these functions: .. _containerizedsorters: Running sorters in Docker/Singularity Containers ------------------------------------------------ +------------------------------------------------ One of the biggest bottlenecks for users is installing spike sorting software. To alleviate this, we build and maintain containerized versions of several popular spike sorters on the diff --git a/src/spikeinterface/preprocessing/motion.py b/src/spikeinterface/preprocessing/motion.py index ff2a5b60c2..5f171b1e15 100644 --- a/src/spikeinterface/preprocessing/motion.py +++ b/src/spikeinterface/preprocessing/motion.py @@ -186,11 +186,11 @@ def correct_motion( Parameters for each step are handled as separate dictionaries. 
For more information please check the documentation of the following functions: - * :py:func:`~spikeinterface.sortingcomponents.peak_detection.detect_peaks' - * :py:func:`~spikeinterface.sortingcomponents.peak_selection.select_peaks' - * :py:func:`~spikeinterface.sortingcomponents.peak_localization.localize_peaks' - * :py:func:`~spikeinterface.sortingcomponents.motion_estimation.estimate_motion' - * :py:func:`~spikeinterface.sortingcomponents.motion_interpolation.interpolate_motion' + * :py:func:`~spikeinterface.sortingcomponents.peak_detection.detect_peaks` + * :py:func:`~spikeinterface.sortingcomponents.peak_selection.select_peaks` + * :py:func:`~spikeinterface.sortingcomponents.peak_localization.localize_peaks` + * :py:func:`~spikeinterface.sortingcomponents.motion_estimation.estimate_motion` + * :py:func:`~spikeinterface.sortingcomponents.motion_interpolation.interpolate_motion` Possible presets: {} From 9c418c9bc19422b4ac42f6cbb5f96339f93f97fa Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Wed, 13 Sep 2023 11:39:35 -0400 Subject: [PATCH 65/73] try just using spinexext --- doc/conf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/conf.py b/doc/conf.py index 15cb65d46a..4bb7301564 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -67,8 +67,8 @@ 'numpydoc', "sphinx.ext.intersphinx", "sphinx.ext.extlinks", - "IPython.sphinxext.ipython_directive", - "IPython.sphinxext.ipython_console_highlighting" + "sphinxext.ipython_directive", + "sphinxext.ipython_console_highlighting" ] numpydoc_show_class_members = False From fff313c1517143fa87b270bc3b9abbb77d6243aa Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Wed, 13 Sep 2023 11:44:16 -0400 Subject: [PATCH 66/73] typo fix --- doc/conf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/conf.py b/doc/conf.py index 4bb7301564..610f13e9ed 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -67,8 +67,8 @@ 'numpydoc', "sphinx.ext.intersphinx", "sphinx.ext.extlinks", - "sphinxext.ipython_directive", - "sphinxext.ipython_console_highlighting" + "sphinx.ext.ipython_directive", + "sphinx.ext.ipython_console_highlighting" ] numpydoc_show_class_members = False From 777a07d3a538394d52a18a05662831a403ee35f9 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Wed, 13 Sep 2023 17:47:35 +0200 Subject: [PATCH 67/73] added to extractors rst (#1992) --- doc/modules/extractors.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/doc/modules/extractors.rst b/doc/modules/extractors.rst index 0a0ad90ffa..5aed24ca41 100644 --- a/doc/modules/extractors.rst +++ b/doc/modules/extractors.rst @@ -129,10 +129,11 @@ For raw recording formats, we currently support: * **MCS RAW** :py:func:`~spikeinterface.extractors.read_mcsraw()` * **MEArec** :py:func:`~spikeinterface.extractors.read_mearec()` * **Mountainsort MDA** :py:func:`~spikeinterface.extractors.read_mda_recording()` +* **Neuralynx** :py:func:`~spikeinterface.extractors.read_neuralynx()` * **Neurodata Without Borders** :py:func:`~spikeinterface.extractors.read_nwb_recording()` * **Neuroscope** :py:func:`~spikeinterface.coextractorsre.read_neuroscope_recording()` +* **Neuroexplorer** :py:func:`~spikeinterface.extractors.read_neuroexplorer()` * **NIX** :py:func:`~spikeinterface.extractors.read_nix()` -* **Neuralynx** :py:func:`~spikeinterface.extractors.read_neuralynx()` * **Open Ephys Legacy** :py:func:`~spikeinterface.extractors.read_openephys()` * **Open Ephys Binary** 
:py:func:`~spikeinterface.extractors.read_openephys()` * **Plexon** :py:func:`~spikeinterface.extractors.read_plexon()` From b569e50e72df1527966aa7ddafefc76dfcf1c053 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Wed, 13 Sep 2023 12:00:45 -0400 Subject: [PATCH 68/73] add ipython to docs --- doc/conf.py | 4 ++-- pyproject.toml | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/doc/conf.py b/doc/conf.py index 610f13e9ed..15cb65d46a 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -67,8 +67,8 @@ 'numpydoc', "sphinx.ext.intersphinx", "sphinx.ext.extlinks", - "sphinx.ext.ipython_directive", - "sphinx.ext.ipython_console_highlighting" + "IPython.sphinxext.ipython_directive", + "IPython.sphinxext.ipython_console_highlighting" ] numpydoc_show_class_members = False diff --git a/pyproject.toml b/pyproject.toml index 474cdc483f..51efe1f585 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -148,6 +148,7 @@ docs = [ "sphinx_rtd_theme==1.0.0", "sphinx-gallery", "numpydoc", + "ipython", # for notebooks in the gallery "MEArec", # Use as an example From 6f26f55d67970a37715e3a01fb961e3fd89780a0 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Wed, 13 Sep 2023 13:12:11 -0400 Subject: [PATCH 69/73] more document fixes, fix toctree --- doc/modules/motion_correction.rst | 2 +- doc/modules/qualitymetrics.rst | 1 + doc/modules/qualitymetrics/silhouette_score.rst | 11 ++++++++++- doc/modules/qualitymetrics/synchrony.rst | 4 +++- 4 files changed, 15 insertions(+), 3 deletions(-) diff --git a/doc/modules/motion_correction.rst b/doc/modules/motion_correction.rst index 62c0d6b8d4..afedc4f982 100644 --- a/doc/modules/motion_correction.rst +++ b/doc/modules/motion_correction.rst @@ -9,7 +9,7 @@ Overview Mechanical drift, often observed in recordings, is currently a major issue for spike sorting. This is especially striking with the new generation of high-density devices used for in-vivo electrophyisology such as the neuropixel electrodes. -The first sorter that has introduced motion/drift correction as a prepossessing step was Kilosort2.5 (see [Steinmetz2021]_) +The first sorter that has introduced motion/drift correction as a prepossessing step was Kilosort2.5 (see [Steinmetz2021]_ [SteinmetzDataset]_) Long story short, the main idea is the same as the one used for non-rigid image registration, for example with calcium imaging. However, because with extracellular recording we do not have a proper image to use as a reference, the main idea diff --git a/doc/modules/qualitymetrics.rst b/doc/modules/qualitymetrics.rst index ee3234af6c..8c7c0a2cc3 100644 --- a/doc/modules/qualitymetrics.rst +++ b/doc/modules/qualitymetrics.rst @@ -38,6 +38,7 @@ For more details about each metric and it's availability and use within SpikeInt qualitymetrics/snr qualitymetrics/noise_cutoff qualitymetrics/silhouette_score + qualitymetrics/synchrony This code snippet shows how to compute quality metrics (with or without principal components) in SpikeInterface: diff --git a/doc/modules/qualitymetrics/silhouette_score.rst b/doc/modules/qualitymetrics/silhouette_score.rst index 275805c6a7..0ce4399710 100644 --- a/doc/modules/qualitymetrics/silhouette_score.rst +++ b/doc/modules/qualitymetrics/silhouette_score.rst @@ -1,3 +1,5 @@ +.. 
_silhouette_score : + Silhouette score (:code:`silhouette`, :code:`silhouette_full`) ============================================================== @@ -7,7 +9,7 @@ Calculation Gives the ratio between the cohesiveness of a cluster and its separation from other clusters. Values for silhouette score range from -1 to 1. -For the full method as proposed by [Rouseeuw]_, the pairwise distances between each point +For the full method as proposed by [Rousseeuw]_, the pairwise distances between each point and every other point :math:`a(i)` in a cluster :math:`C_i` are calculated and then iterating through every other cluster's distances between the points in :math:`C_i` and the points in :math:`C_j` are calculated. The cluster with the minimal mean distance is taken to be :math:`b(i)`. The @@ -48,6 +50,13 @@ To reduce complexity the default implementation in SpikeInterface is to use the This can be changes by switching the silhouette method to either 'full' (the Rousseeuw implementation) or ('simplified', 'full') for both methods when entering the qm_params parameter. +References +---------- + +.. autofunction:: spikeinterface.qualitymetrics.pca_metrics.silhouette + +.. autofunction:: spikeinterface.qualitymetrics.pca_metrics.silhouette_full + Literature ---------- diff --git a/doc/modules/qualitymetrics/synchrony.rst b/doc/modules/qualitymetrics/synchrony.rst index ba40885421..2f566bf8a7 100644 --- a/doc/modules/qualitymetrics/synchrony.rst +++ b/doc/modules/qualitymetrics/synchrony.rst @@ -1,3 +1,5 @@ +.. _synchrony: + Synchrony Metrics (:code:`synchrony`) ===================================== @@ -46,4 +48,4 @@ References Literature ---------- -Based on concepts described in Gruen_ +Based on concepts described in [Gruen]_ From e0accb6885b6dc473090b342e09f067174e16a6b Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Wed, 13 Sep 2023 14:43:51 -0400 Subject: [PATCH 70/73] more warning fixes --- doc/modules/qualitymetrics/silhouette_score.rst | 4 ++-- src/spikeinterface/postprocessing/principal_component.py | 4 ++-- src/spikeinterface/preprocessing/motion.py | 1 + src/spikeinterface/qualitymetrics/pca_metrics.py | 6 ++++++ src/spikeinterface/sortingcomponents/peak_detection.py | 5 +++-- src/spikeinterface/widgets/spikes_on_traces.py | 5 ++--- 6 files changed, 16 insertions(+), 9 deletions(-) diff --git a/doc/modules/qualitymetrics/silhouette_score.rst b/doc/modules/qualitymetrics/silhouette_score.rst index 0ce4399710..b924cdbf73 100644 --- a/doc/modules/qualitymetrics/silhouette_score.rst +++ b/doc/modules/qualitymetrics/silhouette_score.rst @@ -53,9 +53,9 @@ This can be changes by switching the silhouette method to either 'full' (the Rou References ---------- -.. autofunction:: spikeinterface.qualitymetrics.pca_metrics.silhouette +.. autofunction:: spikeinterface.qualitymetrics.pca_metrics.simplified_silhouette_score -.. autofunction:: spikeinterface.qualitymetrics.pca_metrics.silhouette_full +.. autofunction:: spikeinterface.qualitymetrics.pca_metrics.silhouette_score Literature diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py index 991d79506e..100eedc1d1 100644 --- a/src/spikeinterface/postprocessing/principal_component.py +++ b/src/spikeinterface/postprocessing/principal_component.py @@ -694,11 +694,10 @@ def compute_principal_components( If True and pc scores are already in the waveform extractor folders, pc scores are loaded and not recomputed. 
n_components: int Number of components fo PCA - default 5 - mode: str + mode: str, default: 'by_channel_local' - 'by_channel_local': a local PCA is fitted for each channel (projection by channel) - 'by_channel_global': a global PCA is fitted for all channels (projection by channel) - 'concatenated': channels are concatenated and a global PCA is fitted - default 'by_channel_local' sparsity: ChannelSparsity or None The sparsity to apply to waveforms. If waveform_extractor is already sparse, the default sparsity will be used - default None @@ -735,6 +734,7 @@ def compute_principal_components( >>> # run for all spikes in the SortingExtractor >>> pc.run_for_all_spikes(file_path="all_pca_projections.npy") """ + if load_if_exists and waveform_extractor.is_extension(WaveformPrincipalComponent.extension_name): pc = waveform_extractor.load_extension(WaveformPrincipalComponent.extension_name) else: diff --git a/src/spikeinterface/preprocessing/motion.py b/src/spikeinterface/preprocessing/motion.py index 5f171b1e15..cc2ee9d801 100644 --- a/src/spikeinterface/preprocessing/motion.py +++ b/src/spikeinterface/preprocessing/motion.py @@ -186,6 +186,7 @@ def correct_motion( Parameters for each step are handled as separate dictionaries. For more information please check the documentation of the following functions: + * :py:func:`~spikeinterface.sortingcomponents.peak_detection.detect_peaks` * :py:func:`~spikeinterface.sortingcomponents.peak_selection.select_peaks` * :py:func:`~spikeinterface.sortingcomponents.peak_localization.localize_peaks` diff --git a/src/spikeinterface/qualitymetrics/pca_metrics.py b/src/spikeinterface/qualitymetrics/pca_metrics.py index b7b267251d..2383622907 100644 --- a/src/spikeinterface/qualitymetrics/pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/pca_metrics.py @@ -736,6 +736,7 @@ def simplified_silhouette_score(all_pcs, all_labels, this_unit_id): """Calculates the simplified silhouette score for each cluster. The value ranges from -1 (bad clustering) to 1 (good clustering). The simplified silhoutte score utilizes the centroids for distance calculations rather than pairwise calculations. + Parameters ---------- all_pcs : 2d array @@ -744,10 +745,12 @@ def simplified_silhouette_score(all_pcs, all_labels, this_unit_id): The cluster labels for all spikes. Must have length of number of spikes. this_unit_id : int The ID for the unit to calculate this metric for. + Returns ------- unit_silhouette_score : float Simplified Silhouette Score for this unit + References ------------ Based on simplified silhouette score suggested by [Hruschka]_ @@ -782,6 +785,7 @@ def silhouette_score(all_pcs, all_labels, this_unit_id): """Calculates the silhouette score which is a marker of cluster quality ranging from -1 (bad clustering) to 1 (good clustering). Distances are all calculated as pairwise comparisons of all data points. + Parameters ---------- all_pcs : 2d array @@ -790,10 +794,12 @@ def silhouette_score(all_pcs, all_labels, this_unit_id): The cluster labels for all spikes. Must have length of number of spikes. this_unit_id : int The ID for the unit to calculate this metric for. 
+ Returns ------- unit_silhouette_score : float Silhouette Score for this unit + References ------------ Based on [Rousseeuw]_ diff --git a/src/spikeinterface/sortingcomponents/peak_detection.py b/src/spikeinterface/sortingcomponents/peak_detection.py index f3719b934b..bc52ea2c70 100644 --- a/src/spikeinterface/sortingcomponents/peak_detection.py +++ b/src/spikeinterface/sortingcomponents/peak_detection.py @@ -65,10 +65,9 @@ def detect_peaks( This avoid reading the recording multiple times. gather_mode: str How to gather the results: - * "memory": results are returned as in-memory numpy arrays - * "npy": results are stored to .npy files in `folder` + folder: str or Path If gather_mode is "npy", the folder where the files are created. names: list @@ -81,9 +80,11 @@ def detect_peaks( ------- peaks: array Detected peaks. + Notes ----- This peak detection ported from tridesclous into spikeinterface. + """ assert method in detect_peak_methods diff --git a/src/spikeinterface/widgets/spikes_on_traces.py b/src/spikeinterface/widgets/spikes_on_traces.py index e7bcff0832..ae036d1ba1 100644 --- a/src/spikeinterface/widgets/spikes_on_traces.py +++ b/src/spikeinterface/widgets/spikes_on_traces.py @@ -30,11 +30,10 @@ class SpikesOnTracesWidget(BaseWidget): sparsity : ChannelSparsity or None Optional ChannelSparsity to apply. If WaveformExtractor is already sparse, the argument is ignored, default None - unit_colors : dict or None + unit_colors : dict or None If given, a dictionary with unit ids as keys and colors as values, default None If None, then the get_unit_colors() is internally used. (matplotlib backend) - mode : str - Three possible modes, default 'auto': + mode : str in ('line', 'map', 'auto') default: 'auto' * 'line': classical for low channel count * 'map': for high channel count use color heat map * 'auto': auto switch depending on the channel count ('line' if less than 64 channels, 'map' otherwise) From 5a05753be0976eeb41c89a8a93bd8c6142624dc9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 13 Sep 2023 18:44:14 +0000 Subject: [PATCH 71/73] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/postprocessing/principal_component.py | 2 +- src/spikeinterface/preprocessing/motion.py | 2 +- src/spikeinterface/qualitymetrics/pca_metrics.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py index 100eedc1d1..233625e09e 100644 --- a/src/spikeinterface/postprocessing/principal_component.py +++ b/src/spikeinterface/postprocessing/principal_component.py @@ -734,7 +734,7 @@ def compute_principal_components( >>> # run for all spikes in the SortingExtractor >>> pc.run_for_all_spikes(file_path="all_pca_projections.npy") """ - + if load_if_exists and waveform_extractor.is_extension(WaveformPrincipalComponent.extension_name): pc = waveform_extractor.load_extension(WaveformPrincipalComponent.extension_name) else: diff --git a/src/spikeinterface/preprocessing/motion.py b/src/spikeinterface/preprocessing/motion.py index cc2ee9d801..e2ef6e6794 100644 --- a/src/spikeinterface/preprocessing/motion.py +++ b/src/spikeinterface/preprocessing/motion.py @@ -186,7 +186,7 @@ def correct_motion( Parameters for each step are handled as separate dictionaries. 
For more information please check the documentation of the following functions: - + * :py:func:`~spikeinterface.sortingcomponents.peak_detection.detect_peaks` * :py:func:`~spikeinterface.sortingcomponents.peak_selection.select_peaks` * :py:func:`~spikeinterface.sortingcomponents.peak_localization.localize_peaks` diff --git a/src/spikeinterface/qualitymetrics/pca_metrics.py b/src/spikeinterface/qualitymetrics/pca_metrics.py index 2383622907..d1f7534e3c 100644 --- a/src/spikeinterface/qualitymetrics/pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/pca_metrics.py @@ -799,7 +799,7 @@ def silhouette_score(all_pcs, all_labels, this_unit_id): ------- unit_silhouette_score : float Silhouette Score for this unit - + References ------------ Based on [Rousseeuw]_ From 8f382701b36e497a8efb5d05bb840a4222212917 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Wed, 13 Sep 2023 14:58:19 -0400 Subject: [PATCH 72/73] fix docstring typo --- src/spikeinterface/qualitymetrics/pca_metrics.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/qualitymetrics/pca_metrics.py b/src/spikeinterface/qualitymetrics/pca_metrics.py index 2383622907..74ded44111 100644 --- a/src/spikeinterface/qualitymetrics/pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/pca_metrics.py @@ -752,7 +752,7 @@ def simplified_silhouette_score(all_pcs, all_labels, this_unit_id): Simplified Silhouette Score for this unit References - ------------ + ---------- Based on simplified silhouette score suggested by [Hruschka]_ """ @@ -801,7 +801,7 @@ def silhouette_score(all_pcs, all_labels, this_unit_id): Silhouette Score for this unit References - ------------ + ---------- Based on [Rousseeuw]_ """ From d4cb08149a7c14595f763c4df93b975272c8802b Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 14 Sep 2023 11:13:06 +0200 Subject: [PATCH 73/73] Fix last Sphinx warninggit diffgit diff --- src/spikeinterface/widgets/traces.py | 6 +----- src/spikeinterface/widgets/widget_list.py | 6 ++++++ 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/widgets/traces.py b/src/spikeinterface/widgets/traces.py index 9a2ec4a215..e025f779c1 100644 --- a/src/spikeinterface/widgets/traces.py +++ b/src/spikeinterface/widgets/traces.py @@ -26,6 +26,7 @@ class TracesWidget(BaseWidget): List with start time and end time, default None mode: str Three possible modes, default 'auto': + * 'line': classical for low channel count * 'map': for high channel count use color heat map * 'auto': auto switch depending on the channel count ('line' if less than 64 channels, 'map' otherwise) @@ -51,11 +52,6 @@ class TracesWidget(BaseWidget): For 'map' mode and sortingview backend, seconds to render in each row, default 0.2 add_legend : bool If True adds legend to figures, default True - - Returns - ------- - W: TracesWidget - The output widget """ def __init__( diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index f3c640ff16..9c89b3981e 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -54,6 +54,12 @@ {backends} **backend_kwargs: kwargs {backend_kwargs} + + + Returns + ------- + w : BaseWidget + The output widget object. """ # backend_str = f" {list(wcls.possible_backends.keys())}" backend_str = f" {wcls.get_possible_backends()}"
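The patches above document the ``correct_motion`` presets and the per-step functions they wrap. As a minimal, hedged sketch (not taken from any patch in this series; the exact ``folder`` argument and the job-kwargs pass-through are assumptions), the three presets discussed in ``doc/how_to/handle_drift.rst`` could be exercised roughly like this:

.. code:: ipython

    from pathlib import Path
    import spikeinterface.full as si

    # assumed input: any drifting recording, e.g. the SpikeGLX dataset used in the how-to
    rec = si.read_spikeglx("dataset1/NP1")
    rec = si.bandpass_filter(rec, freq_min=300., freq_max=6000.)
    rec = si.common_reference(rec, reference="global", operator="median")

    job_kwargs = dict(n_jobs=4, chunk_duration="1s", progress_bar=True)

    for preset in ("rigid_fast", "kilosort_like", "nonrigid_accurate"):
        # folder= (assumed) persists the motion estimate so it can be reloaded and plotted later
        rec_corrected = si.correct_motion(
            rec,
            preset=preset,
            folder=Path("motion") / preset,
            **job_kwargs,  # assumed: chunking/parallelism kwargs are forwarded to the peak-detection step
        )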