From f498179431d23291f72bbea40bc6a95b65dc5913 Mon Sep 17 00:00:00 2001 From: Julia Sprenger Date: Thu, 22 Jun 2023 10:29:01 +0200 Subject: [PATCH 01/35] Add plexon2 recording, sorting and event support --- .../extractors/neoextractors/__init__.py | 7 +- .../extractors/neoextractors/plexon2.py | 102 ++++++++++++++++++ .../extractors/tests/test_neoextractors.py | 18 ++++ 3 files changed, 125 insertions(+), 2 deletions(-) create mode 100644 src/spikeinterface/extractors/neoextractors/plexon2.py diff --git a/src/spikeinterface/extractors/neoextractors/__init__.py b/src/spikeinterface/extractors/neoextractors/__init__.py index 0d9da1960a..4c12017328 100644 --- a/src/spikeinterface/extractors/neoextractors/__init__.py +++ b/src/spikeinterface/extractors/neoextractors/__init__.py @@ -25,6 +25,8 @@ read_openephys_event, ) from .plexon import PlexonRecordingExtractor, PlexonSortingExtractor, read_plexon, read_plexon_sorting +from .plexon2 import (Plexon2SortingExtractor, Plexon2RecordingExtractor, Plexon2EventExtractor, + read_plexon2, read_plexon2_sorting, read_plexon2_event) from .spike2 import Spike2RecordingExtractor, read_spike2 from .spikegadgets import SpikeGadgetsRecordingExtractor, read_spikegadgets from .spikeglx import SpikeGLXRecordingExtractor, read_spikeglx @@ -49,12 +51,13 @@ OpenEphysBinaryRecordingExtractor, OpenEphysLegacyRecordingExtractor, PlexonRecordingExtractor, + Plexon2RecordingExtractor, Spike2RecordingExtractor, SpikeGadgetsRecordingExtractor, SpikeGLXRecordingExtractor, TdtRecordingExtractor, ] -neo_sorting_extractors_list = [BlackrockSortingExtractor, MEArecSortingExtractor, NeuralynxSortingExtractor] +neo_sorting_extractors_list = [BlackrockSortingExtractor, MEArecSortingExtractor, NeuralynxSortingExtractor, Plexon2SortingExtractor] -neo_event_extractors_list = [AlphaOmegaEventExtractor, OpenEphysBinaryEventExtractor] +neo_event_extractors_list = [AlphaOmegaEventExtractor, OpenEphysBinaryEventExtractor, Plexon2EventExtractor] diff --git a/src/spikeinterface/extractors/neoextractors/plexon2.py b/src/spikeinterface/extractors/neoextractors/plexon2.py new file mode 100644 index 0000000000..5ccceac875 --- /dev/null +++ b/src/spikeinterface/extractors/neoextractors/plexon2.py @@ -0,0 +1,102 @@ +from spikeinterface.core.core_tools import define_function_from_class + +from .neobaseextractor import (NeoBaseRecordingExtractor, NeoBaseSortingExtractor, + NeoBaseEventExtractor) + + +class Plexon2RecordingExtractor(NeoBaseRecordingExtractor): + """ + Class for reading plexon pl2 files. + + Based on :py:class:`neo.rawio.Plexon2RawIO` + + Parameters + ---------- + file_path: str + The file path to load the recordings from. + stream_id: str, optional + If there are several streams, specify the stream id you want to load. + stream_name: str, optional + If there are several streams, specify the stream name you want to load. + all_annotations: bool, default: False + Load exhaustively all annotations from neo. 
+ """ + + mode = "file" + NeoRawIOClass = "Plexon2RawIO" + name = "plexon2" + + def __init__(self, file_path, stream_id=None, stream_name=None, all_annotations=False): + neo_kwargs = self.map_to_neo_kwargs(file_path) + NeoBaseRecordingExtractor.__init__( + self, stream_id=stream_id, stream_name=stream_name, all_annotations=all_annotations, **neo_kwargs + ) + self._kwargs.update({"file_path": str(file_path)}) + + @classmethod + def map_to_neo_kwargs(cls, file_path): + neo_kwargs = {"filename": str(file_path)} + return neo_kwargs + + +class Plexon2SortingExtractor(NeoBaseSortingExtractor): + """ + Class for reading plexon spiking data from .pl2 files. + + Based on :py:class:`neo.rawio.Plexon2RawIO` + + Parameters + ---------- + file_path: str + The file path to load the recordings from. + """ + + mode = "file" + NeoRawIOClass = "Plexon2RawIO" + handle_spike_frame_directly = True + name = "plexon2" + + def __init__(self, file_path): + from neo.rawio import Plexon2RawIO + + neo_kwargs = self.map_to_neo_kwargs(file_path) + neo_reader = Plexon2RawIO(**neo_kwargs) + neo_reader.parse_header() + NeoBaseSortingExtractor.__init__(self, **neo_kwargs) + self._kwargs.update({"file_path": str(file_path)}) + + @classmethod + def map_to_neo_kwargs(cls, file_path): + neo_kwargs = {"filename": str(file_path)} + return neo_kwargs + + +class Plexon2EventExtractor(NeoBaseEventExtractor): + """ + Class for reading plexon spiking data from .pl2 files. + + Based on :py:class:`neo.rawio.Plexon2RawIO` + + Parameters + ---------- + folder_path: str + + """ + + mode = "file" + NeoRawIOClass = "Plexon2RawIO" + name = "plexon2" + + def __init__(self, folder_path, block_index=None): + neo_kwargs = self.map_to_neo_kwargs(folder_path) + NeoBaseEventExtractor.__init__(self, block_index=block_index, **neo_kwargs) + + @classmethod + def map_to_neo_kwargs(cls, folder_path): + neo_kwargs = {"filename": str(folder_path)} + return neo_kwargs + + +read_plexon2 = define_function_from_class(source_class=Plexon2RecordingExtractor, name="read_plexon2") +read_plexon2_sorting = define_function_from_class(source_class=Plexon2SortingExtractor, name="read_plexon2_sorting") +read_plexon2_event = define_function_from_class(source_class=Plexon2EventExtractor, name='read_plexon2_event') diff --git a/src/spikeinterface/extractors/tests/test_neoextractors.py b/src/spikeinterface/extractors/tests/test_neoextractors.py index d19574e094..a28752acdd 100644 --- a/src/spikeinterface/extractors/tests/test_neoextractors.py +++ b/src/spikeinterface/extractors/tests/test_neoextractors.py @@ -120,6 +120,12 @@ class NeuroScopeSortingTest(SortingCommonTestSuite, unittest.TestCase): }, ] +class Plexon2EventTest(EventCommonTestSuite, unittest.TestCase): + ExtractorClass = Plexon2EventExtractor + downloads = ["plexon"] + entities = [ + ("plexon/4chDemoPL2.pl2"), + ] class PlexonRecordingTest(RecordingCommonTestSuite, unittest.TestCase): ExtractorClass = PlexonRecordingExtractor @@ -128,6 +134,12 @@ class PlexonRecordingTest(RecordingCommonTestSuite, unittest.TestCase): "plexon/File_plexon_3.plx", ] +class Plexon2RecordingTest(RecordingCommonTestSuite, unittest.TestCase): + ExtractorClass = Plexon2RecordingExtractor + downloads = ["plexon"] + entities = [ + ("plexon/4chDemoPL2.pl2", {"stream_id": "3"}), + ] class PlexonSortingTest(SortingCommonTestSuite, unittest.TestCase): ExtractorClass = PlexonSortingExtractor @@ -136,6 +148,12 @@ class PlexonSortingTest(SortingCommonTestSuite, unittest.TestCase): ("plexon/File_plexon_1.plx"), ] +class 
Plexon2SortingTest(SortingCommonTestSuite, unittest.TestCase): + ExtractorClass = Plexon2SortingExtractor + downloads = ["plexon"] + entities = [ + ("plexon/4chDemoPL2.pl2"), + ] class NeuralynxRecordingTest(RecordingCommonTestSuite, unittest.TestCase): ExtractorClass = NeuralynxRecordingExtractor From 9f42895213a47ddb9158e4cccb48dcec1dea9549 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 22 Jun 2023 08:31:48 +0000 Subject: [PATCH 02/35] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../extractors/neoextractors/__init__.py | 17 ++++++++++++++--- .../extractors/neoextractors/plexon2.py | 5 ++--- .../extractors/tests/test_neoextractors.py | 6 ++++++ 3 files changed, 22 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/extractors/neoextractors/__init__.py b/src/spikeinterface/extractors/neoextractors/__init__.py index 4c12017328..3360b76147 100644 --- a/src/spikeinterface/extractors/neoextractors/__init__.py +++ b/src/spikeinterface/extractors/neoextractors/__init__.py @@ -25,8 +25,14 @@ read_openephys_event, ) from .plexon import PlexonRecordingExtractor, PlexonSortingExtractor, read_plexon, read_plexon_sorting -from .plexon2 import (Plexon2SortingExtractor, Plexon2RecordingExtractor, Plexon2EventExtractor, - read_plexon2, read_plexon2_sorting, read_plexon2_event) +from .plexon2 import ( + Plexon2SortingExtractor, + Plexon2RecordingExtractor, + Plexon2EventExtractor, + read_plexon2, + read_plexon2_sorting, + read_plexon2_event, +) from .spike2 import Spike2RecordingExtractor, read_spike2 from .spikegadgets import SpikeGadgetsRecordingExtractor, read_spikegadgets from .spikeglx import SpikeGLXRecordingExtractor, read_spikeglx @@ -58,6 +64,11 @@ TdtRecordingExtractor, ] -neo_sorting_extractors_list = [BlackrockSortingExtractor, MEArecSortingExtractor, NeuralynxSortingExtractor, Plexon2SortingExtractor] +neo_sorting_extractors_list = [ + BlackrockSortingExtractor, + MEArecSortingExtractor, + NeuralynxSortingExtractor, + Plexon2SortingExtractor, +] neo_event_extractors_list = [AlphaOmegaEventExtractor, OpenEphysBinaryEventExtractor, Plexon2EventExtractor] diff --git a/src/spikeinterface/extractors/neoextractors/plexon2.py b/src/spikeinterface/extractors/neoextractors/plexon2.py index 5ccceac875..c3869dbadc 100644 --- a/src/spikeinterface/extractors/neoextractors/plexon2.py +++ b/src/spikeinterface/extractors/neoextractors/plexon2.py @@ -1,7 +1,6 @@ from spikeinterface.core.core_tools import define_function_from_class -from .neobaseextractor import (NeoBaseRecordingExtractor, NeoBaseSortingExtractor, - NeoBaseEventExtractor) +from .neobaseextractor import NeoBaseRecordingExtractor, NeoBaseSortingExtractor, NeoBaseEventExtractor class Plexon2RecordingExtractor(NeoBaseRecordingExtractor): @@ -99,4 +98,4 @@ def map_to_neo_kwargs(cls, folder_path): read_plexon2 = define_function_from_class(source_class=Plexon2RecordingExtractor, name="read_plexon2") read_plexon2_sorting = define_function_from_class(source_class=Plexon2SortingExtractor, name="read_plexon2_sorting") -read_plexon2_event = define_function_from_class(source_class=Plexon2EventExtractor, name='read_plexon2_event') +read_plexon2_event = define_function_from_class(source_class=Plexon2EventExtractor, name="read_plexon2_event") diff --git a/src/spikeinterface/extractors/tests/test_neoextractors.py b/src/spikeinterface/extractors/tests/test_neoextractors.py index a28752acdd..b14bcc9cf8 100644 --- 
a/src/spikeinterface/extractors/tests/test_neoextractors.py +++ b/src/spikeinterface/extractors/tests/test_neoextractors.py @@ -120,6 +120,7 @@ class NeuroScopeSortingTest(SortingCommonTestSuite, unittest.TestCase): }, ] + class Plexon2EventTest(EventCommonTestSuite, unittest.TestCase): ExtractorClass = Plexon2EventExtractor downloads = ["plexon"] @@ -127,6 +128,7 @@ class Plexon2EventTest(EventCommonTestSuite, unittest.TestCase): ("plexon/4chDemoPL2.pl2"), ] + class PlexonRecordingTest(RecordingCommonTestSuite, unittest.TestCase): ExtractorClass = PlexonRecordingExtractor downloads = ["plexon"] @@ -134,6 +136,7 @@ class PlexonRecordingTest(RecordingCommonTestSuite, unittest.TestCase): "plexon/File_plexon_3.plx", ] + class Plexon2RecordingTest(RecordingCommonTestSuite, unittest.TestCase): ExtractorClass = Plexon2RecordingExtractor downloads = ["plexon"] @@ -141,6 +144,7 @@ class Plexon2RecordingTest(RecordingCommonTestSuite, unittest.TestCase): ("plexon/4chDemoPL2.pl2", {"stream_id": "3"}), ] + class PlexonSortingTest(SortingCommonTestSuite, unittest.TestCase): ExtractorClass = PlexonSortingExtractor downloads = ["plexon"] @@ -148,6 +152,7 @@ class PlexonSortingTest(SortingCommonTestSuite, unittest.TestCase): ("plexon/File_plexon_1.plx"), ] + class Plexon2SortingTest(SortingCommonTestSuite, unittest.TestCase): ExtractorClass = Plexon2SortingExtractor downloads = ["plexon"] @@ -155,6 +160,7 @@ class Plexon2SortingTest(SortingCommonTestSuite, unittest.TestCase): ("plexon/4chDemoPL2.pl2"), ] + class NeuralynxRecordingTest(RecordingCommonTestSuite, unittest.TestCase): ExtractorClass = NeuralynxRecordingExtractor downloads = ["neuralynx"] From a8e9924ac9d14dd7ec5f116112866846eac2e9e2 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 7 Jul 2023 10:42:07 +0200 Subject: [PATCH 03/35] Start waveforme xtarctor in one buffer --- src/spikeinterface/core/waveform_tools.py | 217 +++++++++++++++++++++- 1 file changed, 213 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index a10c209f47..a68f8cfd5f 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -257,8 +257,8 @@ def distribute_waveforms_to_buffers( inds_by_unit[unit_id] = inds # and run - func = _waveform_extractor_chunk - init_func = _init_worker_waveform_extractor + func = _worker_ditribute_buffers + init_func = _init_worker_ditribute_buffers init_args = ( recording, @@ -282,7 +282,7 @@ def distribute_waveforms_to_buffers( # used by ChunkRecordingExecutor -def _init_worker_waveform_extractor( +def _init_worker_ditribute_buffers( recording, unit_ids, spikes, wfs_arrays_info, nbefore, nafter, return_scaled, inds_by_unit, mode, sparsity_mask ): # create a local dict per worker @@ -328,7 +328,216 @@ def _init_worker_waveform_extractor( # used by ChunkRecordingExecutor -def _waveform_extractor_chunk(segment_index, start_frame, end_frame, worker_ctx): +def _worker_ditribute_buffers(segment_index, start_frame, end_frame, worker_ctx): + # recover variables of the worker + recording = worker_ctx["recording"] + unit_ids = worker_ctx["unit_ids"] + spikes = worker_ctx["spikes"] + nbefore = worker_ctx["nbefore"] + nafter = worker_ctx["nafter"] + return_scaled = worker_ctx["return_scaled"] + inds_by_unit = worker_ctx["inds_by_unit"] + sparsity_mask = worker_ctx["sparsity_mask"] + + seg_size = recording.get_num_samples(segment_index=segment_index) + + # take only spikes with the correct segment_index + # this is a slice so no 
copy!! + s0 = np.searchsorted(spikes["segment_index"], segment_index) + s1 = np.searchsorted(spikes["segment_index"], segment_index + 1) + in_seg_spikes = spikes[s0:s1] + + # take only spikes in range [start_frame, end_frame] + # this is a slice so no copy!! + i0 = np.searchsorted(in_seg_spikes["sample_index"], start_frame) + i1 = np.searchsorted(in_seg_spikes["sample_index"], end_frame) + if i0 != i1: + # protect from spikes on border : spike_time<0 or spike_time>seg_size + # useful only when max_spikes_per_unit is not None + # waveform will not be extracted and a zeros will be left in the memmap file + while (in_seg_spikes[i0]["sample_index"] - nbefore) < 0 and (i0 != i1): + i0 = i0 + 1 + while (in_seg_spikes[i1 - 1]["sample_index"] + nafter) > seg_size and (i0 != i1): + i1 = i1 - 1 + + # slice in absolut in spikes vector + l0 = i0 + s0 + l1 = i1 + s0 + + if l1 > l0: + start = spikes[l0]["sample_index"] - nbefore + end = spikes[l1 - 1]["sample_index"] + nafter + + # load trace in memory + traces = recording.get_traces( + start_frame=start, end_frame=end, segment_index=segment_index, return_scaled=return_scaled + ) + + for unit_ind, unit_id in enumerate(unit_ids): + # find pos + inds = inds_by_unit[unit_id] + (in_chunk_pos,) = np.nonzero((inds >= l0) & (inds < l1)) + if in_chunk_pos.size == 0: + continue + + if worker_ctx["mode"] == "memmap": + # open file in demand (and also autoclose it after) + filename = worker_ctx["wfs_arrays_info"][unit_id] + wfs = np.load(str(filename), mmap_mode="r+") + elif worker_ctx["mode"] == "shared_memory": + wfs = worker_ctx["wfs_arrays"][unit_id] + + for pos in in_chunk_pos: + sample_index = spikes[inds[pos]]["sample_index"] + wf = traces[sample_index - start - nbefore : sample_index - start + nafter, :] + + if sparsity_mask is None: + wfs[pos, :, :] = wf + else: + wfs[pos, :, :] = wf[:, sparsity_mask[unit_ind]] + + +def extract_waveforms_to_unique_buffer( + recording, + spikes, + unit_ids, + nbefore, + nafter, + mode="memmap", + return_scaled=False, + folder=None, + dtype=None, + sparsity_mask=None, + copy=False, + **job_kwargs, +): + + nsamples = nbefore + nafter + + dtype = np.dtype(dtype) + if mode == "shared_memory": + assert folder is None + else: + folder = Path(folder) + + num_spikes = spike.size + if sparsity_mask is None: + num_chans = recording.get_num_channels() + else: + num_chans = np.sum(sparsity_mask[unit_ind, :]) + shape = (num_spikes, nsamples, num_chans) + + if mode == "memmap": + filename = str(folder / f"all_waveforms.npy") + wfs_array = np.lib.format.open_memmap(filename, mode="w+", dtype=dtype, shape=shape) + wf_array_info = filename + elif mode == "shared_memory": + if num_spikes == 0: + wfs_array = np.zeros(shape, dtype=dtype) + shm = None + shm_name = None + else: + wfs_array, shm = make_shared_array(shape, dtype) + shm_name = shm.name + wf_array_info = (shm, shm_name, dtype.str, shape) + else: + raise ValueError("allocate_waveforms_buffers bad mode") + + + job_kwargs = fix_job_kwargs(job_kwargs) + + inds_by_unit = {} + for unit_ind, unit_id in enumerate(unit_ids): + (inds,) = np.nonzero(spikes["unit_index"] == unit_ind) + inds_by_unit[unit_id] = inds + + if num_spikes > 0: + # and run + func = _worker_ditribute_one_buffer + init_func = _init_worker_ditribute_buffers + + init_args = ( + recording, + unit_ids, + spikes, + wf_array_info, + nbefore, + nafter, + return_scaled, + mode, + sparsity_mask, + ) + processor = ChunkRecordingExecutor( + recording, func, init_func, init_args, job_name=f"extract waveforms {mode}", **job_kwargs + 
) + processor.run() + + + # if mode == "memmap": + # return wfs_arrays + # elif mode == "shared_memory": + # if copy: + # wfs_arrays = {unit_id: arr.copy() for unit_id, arr in wfs_arrays.items()} + # # release all sharedmem buffer + # for unit_id in unit_ids: + # shm = wfs_arrays_info[unit_id][0] + # if shm is not None: + # # empty array have None + # shm.unlink() + # return wfs_arrays + # else: + # return wfs_arrays, wfs_arrays_info + + + + +def _init_worker_ditribute_one_buffer( + recording, unit_ids, spikes, wf_array_info, nbefore, nafter, return_scaled, mode, sparsity_mask +): + + worker_ctx = {} + worker_ctx["recording"] = recording + worker_ctx["wf_array_info"] = wf_array_info + + if mode == "memmap": + filename = wf_array_info + wfs = np.load(str(filename), mmap_mode="r+") + + # in memmap mode we have the "too many open file" problem with linux + # memmap file will be open on demand and not globally per worker + worker_ctx["wf_array_info"] = wf_array_info + elif mode == "shared_memory": + from multiprocessing.shared_memory import SharedMemory + + wfs_arrays = {} + shms = {} + for unit_id, (shm, shm_name, dtype, shape) in wfs_arrays_info.items(): + if shm_name is None: + arr = np.zeros(shape=shape, dtype=dtype) + else: + shm = SharedMemory(shm_name) + arr = np.ndarray(shape=shape, dtype=dtype, buffer=shm.buf) + wfs_arrays[unit_id] = arr + # we need a reference to all sham otherwise we get segment fault!!! + shms[unit_id] = shm + worker_ctx["shms"] = shms + worker_ctx["wfs_arrays"] = wfs_arrays + + worker_ctx["unit_ids"] = unit_ids + worker_ctx["spikes"] = spikes + + worker_ctx["nbefore"] = nbefore + worker_ctx["nafter"] = nafter + worker_ctx["return_scaled"] = return_scaled + worker_ctx["inds_by_unit"] = inds_by_unit + worker_ctx["sparsity_mask"] = sparsity_mask + worker_ctx["mode"] = mode + + return worker_ctx + + +# used by ChunkRecordingExecutor +def _worker_ditribute_one_buffer(segment_index, start_frame, end_frame, worker_ctx): # recover variables of the worker recording = worker_ctx["recording"] unit_ids = worker_ctx["unit_ids"] From bb4457505e22fb5074cf17060391e4e5ce91a80f Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 7 Jul 2023 14:37:15 +0200 Subject: [PATCH 04/35] wip waveform tools speedup --- .../core/tests/test_waveform_tools.py | 29 +++- src/spikeinterface/core/waveform_tools.py | 124 ++++++++---------- 2 files changed, 77 insertions(+), 76 deletions(-) diff --git a/src/spikeinterface/core/tests/test_waveform_tools.py b/src/spikeinterface/core/tests/test_waveform_tools.py index a896ff9c8b..457be5cba4 100644 --- a/src/spikeinterface/core/tests/test_waveform_tools.py +++ b/src/spikeinterface/core/tests/test_waveform_tools.py @@ -7,7 +7,7 @@ from spikeinterface.core import generate_recording, generate_sorting from spikeinterface.core.waveform_tools import ( - extract_waveforms_to_buffers, + extract_waveforms_to_buffers, extract_waveforms_to_unique_buffer, ) # allocate_waveforms_buffers, distribute_waveforms_to_buffers @@ -64,8 +64,7 @@ def test_waveform_tools(): if wf_folder.is_dir(): shutil.rmtree(wf_folder) wf_folder.mkdir() - # wfs_arrays, wfs_arrays_info = allocate_waveforms_buffers(recording, spikes, unit_ids, nbefore, nafter, mode='memmap', folder=wf_folder, dtype=dtype) - # distribute_waveforms_to_buffers(recording, spikes, unit_ids, wfs_arrays_info, nbefore, nafter, return_scaled, **job_kwargs) + wfs_arrays = extract_waveforms_to_buffers( recording, spikes, @@ -84,8 +83,27 @@ def test_waveform_tools(): wf = wfs_arrays[unit_id] assert wf.shape[0] == 
np.sum(spikes["unit_index"] == unit_ind) list_wfs.append({unit_id: wfs_arrays[unit_id].copy() for unit_id in unit_ids}) + + wfs_array = extract_waveforms_to_unique_buffer( + recording, + spikes, + unit_ids, + nbefore, + nafter, + mode="memmap", + return_scaled=False, + folder=wf_folder, + dtype=dtype, + sparsity_mask=None, + copy=False, + **job_kwargs, + ) + print(wfs_array.shape) + _check_all_wf_equal(list_wfs) + + # memory if platform.system() != "Windows": # shared memory on windows is buggy... @@ -125,9 +143,6 @@ def test_waveform_tools(): sparsity_mask = np.random.randint(0, 2, size=(unit_ids.size, recording.channel_ids.size), dtype="bool") job_kwargs = {"n_jobs": 1, "chunk_size": 3000, "progress_bar": True} - # wfs_arrays, wfs_arrays_info = allocate_waveforms_buffers(recording, spikes, unit_ids, nbefore, nafter, mode='memmap', folder=wf_folder, dtype=dtype, sparsity_mask=sparsity_mask) - # distribute_waveforms_to_buffers(recording, spikes, unit_ids, wfs_arrays_info, nbefore, nafter, return_scaled, sparsity_mask=sparsity_mask, **job_kwargs) - wfs_arrays = extract_waveforms_to_buffers( recording, spikes, @@ -144,5 +159,7 @@ def test_waveform_tools(): ) + + if __name__ == "__main__": test_waveform_tools() diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index a68f8cfd5f..a7c8493381 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -214,6 +214,7 @@ def distribute_waveforms_to_buffers( return_scaled, mode="memmap", sparsity_mask=None, + job_name=None, **job_kwargs, ): """ @@ -272,9 +273,9 @@ def distribute_waveforms_to_buffers( mode, sparsity_mask, ) - processor = ChunkRecordingExecutor( - recording, func, init_func, init_args, job_name=f"extract waveforms {mode}", **job_kwargs - ) + if job_name is None: + job_name=f"extract waveforms {mode} multi buffer" + processor = ChunkRecordingExecutor(recording, func, init_func, init_args, job_name=job_name, **job_kwargs) processor.run() @@ -409,6 +410,7 @@ def extract_waveforms_to_unique_buffer( dtype=None, sparsity_mask=None, copy=False, + job_name=None, **job_kwargs, ): @@ -420,11 +422,11 @@ def extract_waveforms_to_unique_buffer( else: folder = Path(folder) - num_spikes = spike.size + num_spikes = spikes.size if sparsity_mask is None: num_chans = recording.get_num_channels() else: - num_chans = np.sum(sparsity_mask[unit_ind, :]) + num_chans = max(np.sum(sparsity_mask, axis=1)) shape = (num_spikes, nsamples, num_chans) if mode == "memmap": @@ -454,7 +456,7 @@ def extract_waveforms_to_unique_buffer( if num_spikes > 0: # and run func = _worker_ditribute_one_buffer - init_func = _init_worker_ditribute_buffers + init_func = _init_worker_ditribute_one_buffer init_args = ( recording, @@ -466,27 +468,27 @@ def extract_waveforms_to_unique_buffer( return_scaled, mode, sparsity_mask, + ) - processor = ChunkRecordingExecutor( - recording, func, init_func, init_args, job_name=f"extract waveforms {mode}", **job_kwargs - ) + if job_name is None: + job_name = f"extract waveforms {mode} mono buffer" + + processor = ChunkRecordingExecutor(recording, func, init_func, init_args, job_name=job_name, **job_kwargs) processor.run() - # if mode == "memmap": - # return wfs_arrays - # elif mode == "shared_memory": - # if copy: - # wfs_arrays = {unit_id: arr.copy() for unit_id, arr in wfs_arrays.items()} - # # release all sharedmem buffer - # for unit_id in unit_ids: - # shm = wfs_arrays_info[unit_id][0] - # if shm is not None: - # # empty array have None - # 
shm.unlink() - # return wfs_arrays - # else: - # return wfs_arrays, wfs_arrays_info + if mode == "memmap": + return wfs_array + elif mode == "shared_memory": + if copy: + wf_array_info = wf_array_info.copy() + if shm is not None: + # release all sharedmem buffer + # empty array have None + shm.unlink() + return wfs_array + else: + return wfs_array, wf_array_info @@ -501,35 +503,29 @@ def _init_worker_ditribute_one_buffer( if mode == "memmap": filename = wf_array_info - wfs = np.load(str(filename), mmap_mode="r+") - - # in memmap mode we have the "too many open file" problem with linux - # memmap file will be open on demand and not globally per worker - worker_ctx["wf_array_info"] = wf_array_info + wfs_array = np.load(str(filename), mmap_mode="r+") + worker_ctx["wfs_array"] = wfs_array elif mode == "shared_memory": from multiprocessing.shared_memory import SharedMemory - - wfs_arrays = {} - shms = {} - for unit_id, (shm, shm_name, dtype, shape) in wfs_arrays_info.items(): - if shm_name is None: - arr = np.zeros(shape=shape, dtype=dtype) - else: - shm = SharedMemory(shm_name) - arr = np.ndarray(shape=shape, dtype=dtype, buffer=shm.buf) - wfs_arrays[unit_id] = arr - # we need a reference to all sham otherwise we get segment fault!!! - shms[unit_id] = shm - worker_ctx["shms"] = shms - worker_ctx["wfs_arrays"] = wfs_arrays + shm, shm_name, dtype, shape = wf_array_info + shm = SharedMemory(shm_name) + wfs_array = np.ndarray(shape=shape, dtype=dtype, buffer=shm.buf) + worker_ctx["shm"] = shm + worker_ctx["wfs_array"] = wfs_array + + # prepare segment slices + segment_slices = [] + for segment_index in range(recording.get_num_segments()): + s0 = np.searchsorted(spikes["segment_index"], segment_index) + s1 = np.searchsorted(spikes["segment_index"], segment_index + 1) + segment_slices.append((s0, s1)) + worker_ctx["segment_slices"] = segment_slices worker_ctx["unit_ids"] = unit_ids worker_ctx["spikes"] = spikes - worker_ctx["nbefore"] = nbefore worker_ctx["nafter"] = nafter worker_ctx["return_scaled"] = return_scaled - worker_ctx["inds_by_unit"] = inds_by_unit worker_ctx["sparsity_mask"] = sparsity_mask worker_ctx["mode"] = mode @@ -541,20 +537,18 @@ def _worker_ditribute_one_buffer(segment_index, start_frame, end_frame, worker_c # recover variables of the worker recording = worker_ctx["recording"] unit_ids = worker_ctx["unit_ids"] + segment_slices = worker_ctx["segment_slices"] spikes = worker_ctx["spikes"] nbefore = worker_ctx["nbefore"] nafter = worker_ctx["nafter"] return_scaled = worker_ctx["return_scaled"] - inds_by_unit = worker_ctx["inds_by_unit"] sparsity_mask = worker_ctx["sparsity_mask"] + wfs_array = worker_ctx["wfs_array"] seg_size = recording.get_num_samples(segment_index=segment_index) - # take only spikes with the correct segment_index - # this is a slice so no copy!! - s0 = np.searchsorted(spikes["segment_index"], segment_index) - s1 = np.searchsorted(spikes["segment_index"], segment_index + 1) - in_seg_spikes = spikes[s0:s1] + s0, s1 = segment_slices[segment_index] + in_seg_spikes = spikes[s0: s1] # take only spikes in range [start_frame, end_frame] # this is a slice so no copy!! 
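
The speedup in this patch hinges on the spike vector being sorted by segment_index and then sample_index: np.searchsorted then yields contiguous slices (views, not copies) both per segment and per chunk, which is what the hunk above and the one below rely on. A minimal, self-contained illustration of that slicing pattern on a synthetic spike vector — the field names mirror Sorting.to_spike_vector(), the data and numbers are made up:

    import numpy as np

    # synthetic spike vector, sorted by (segment_index, sample_index) as Sorting.to_spike_vector() returns it
    dtype = [("sample_index", "int64"), ("unit_index", "int64"), ("segment_index", "int64")]
    spikes = np.zeros(7, dtype=dtype)
    spikes["segment_index"] = [0, 0, 0, 0, 1, 1, 1]
    spikes["sample_index"] = [10, 250, 3100, 5999, 40, 2800, 9000]
    spikes["unit_index"] = [0, 1, 0, 1, 1, 0, 0]

    # per-segment slices, computed once per worker
    num_segments = int(spikes["segment_index"].max()) + 1
    segment_slices = []
    for segment_index in range(num_segments):
        s0 = np.searchsorted(spikes["segment_index"], segment_index)
        s1 = np.searchsorted(spikes["segment_index"], segment_index + 1)
        segment_slices.append((s0, s1))

    # inside one chunk: spikes of segment 0 falling in [start_frame, end_frame)
    segment_index, start_frame, end_frame = 0, 1000, 6000
    s0, s1 = segment_slices[segment_index]
    in_seg_spikes = spikes[s0:s1]                                   # slice -> view, no copy
    i0 = np.searchsorted(in_seg_spikes["sample_index"], start_frame)
    i1 = np.searchsorted(in_seg_spikes["sample_index"], end_frame)
    l0, l1 = i0 + s0, i1 + s0                                       # absolute indices in the full spike vector
    print(spikes[l0:l1]["sample_index"])                            # -> [3100 5999]
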
@@ -582,28 +576,18 @@ def _worker_ditribute_one_buffer(segment_index, start_frame, end_frame, worker_c start_frame=start, end_frame=end, segment_index=segment_index, return_scaled=return_scaled ) - for unit_ind, unit_id in enumerate(unit_ids): - # find pos - inds = inds_by_unit[unit_id] - (in_chunk_pos,) = np.nonzero((inds >= l0) & (inds < l1)) - if in_chunk_pos.size == 0: - continue + for spike_ind in range(l0, l1): + sample_index = spikes[spike_ind]["sample_index"] + unit_index = spikes[spike_ind]["unit_index"] + wf = traces[sample_index - start - nbefore : sample_index - start + nafter, :] - if worker_ctx["mode"] == "memmap": - # open file in demand (and also autoclose it after) - filename = worker_ctx["wfs_arrays_info"][unit_id] - wfs = np.load(str(filename), mmap_mode="r+") - elif worker_ctx["mode"] == "shared_memory": - wfs = worker_ctx["wfs_arrays"][unit_id] - - for pos in in_chunk_pos: - sample_index = spikes[inds[pos]]["sample_index"] - wf = traces[sample_index - start - nbefore : sample_index - start + nafter, :] + if sparsity_mask is None: + wfs_array[spike_ind, :, :] = None + else: + mask = sparsity_mask[unit_index, :] + wf = wf[:, mask] + wfs_array[spike_ind, :, :wf.shape[1]] = wf - if sparsity_mask is None: - wfs[pos, :, :] = wf - else: - wfs[pos, :, :] = wf[:, sparsity_mask[unit_ind]] def has_exceeding_spikes(recording, sorting): From 5539bdb0028b2bdcebb6b4e6d201f0579770f3e1 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 7 Jul 2023 17:02:02 +0200 Subject: [PATCH 05/35] extract_waveforms_to_unique_buffer is more or less OK --- .../core/tests/test_waveform_tools.py | 181 ++++++++---------- src/spikeinterface/core/waveform_tools.py | 59 +++--- 2 files changed, 115 insertions(+), 125 deletions(-) diff --git a/src/spikeinterface/core/tests/test_waveform_tools.py b/src/spikeinterface/core/tests/test_waveform_tools.py index 457be5cba4..ef75180898 100644 --- a/src/spikeinterface/core/tests/test_waveform_tools.py +++ b/src/spikeinterface/core/tests/test_waveform_tools.py @@ -7,8 +7,8 @@ from spikeinterface.core import generate_recording, generate_sorting from spikeinterface.core.waveform_tools import ( - extract_waveforms_to_buffers, extract_waveforms_to_unique_buffer, -) # allocate_waveforms_buffers, distribute_waveforms_to_buffers + extract_waveforms_to_buffers, extract_waveforms_to_unique_buffer, split_waveforms_by_units, +) if hasattr(pytest, "global_test_folder"): @@ -21,6 +21,10 @@ def _check_all_wf_equal(list_wfs_arrays): wfs_arrays0 = list_wfs_arrays[0] for i, wfs_arrays in enumerate(list_wfs_arrays): for unit_id in wfs_arrays.keys(): + print() + print('*'*10) + print(wfs_arrays[unit_id].shape) + print(wfs_arrays0[unit_id].shape) assert np.array_equal(wfs_arrays[unit_id], wfs_arrays0[unit_id]) @@ -52,111 +56,86 @@ def test_waveform_tools(): unit_ids = sorting.unit_ids some_job_kwargs = [ - {}, {"n_jobs": 1, "chunk_size": 3000, "progress_bar": True}, {"n_jobs": 2, "chunk_size": 3000, "progress_bar": True}, ] + some_modes = [ + {"mode" : "memmap"}, + ] + if platform.system() != "Windows": + # shared memory on windows is buggy... 
+ some_modes.append({"mode" : "shared_memory", }) + + some_sparsity = [ + dict(sparsity_mask=None), + dict(sparsity_mask = np.random.randint(0, 2, size=(unit_ids.size, recording.channel_ids.size), dtype="bool")), + ] + # memmap mode - list_wfs = [] + list_wfs_dense = [] + list_wfs_sparse = [] for j, job_kwargs in enumerate(some_job_kwargs): - wf_folder = cache_folder / f"test_waveform_tools_{j}" - if wf_folder.is_dir(): - shutil.rmtree(wf_folder) - wf_folder.mkdir() - - wfs_arrays = extract_waveforms_to_buffers( - recording, - spikes, - unit_ids, - nbefore, - nafter, - mode="memmap", - return_scaled=False, - folder=wf_folder, - dtype=dtype, - sparsity_mask=None, - copy=False, - **job_kwargs, - ) - for unit_ind, unit_id in enumerate(unit_ids): - wf = wfs_arrays[unit_id] - assert wf.shape[0] == np.sum(spikes["unit_index"] == unit_ind) - list_wfs.append({unit_id: wfs_arrays[unit_id].copy() for unit_id in unit_ids}) - - wfs_array = extract_waveforms_to_unique_buffer( - recording, - spikes, - unit_ids, - nbefore, - nafter, - mode="memmap", - return_scaled=False, - folder=wf_folder, - dtype=dtype, - sparsity_mask=None, - copy=False, - **job_kwargs, - ) - print(wfs_array.shape) - - _check_all_wf_equal(list_wfs) - - - - # memory - if platform.system() != "Windows": - # shared memory on windows is buggy... - list_wfs = [] - for job_kwargs in some_job_kwargs: - # wfs_arrays, wfs_arrays_info = allocate_waveforms_buffers(recording, spikes, unit_ids, nbefore, nafter, mode='shared_memory', folder=None, dtype=dtype) - # distribute_waveforms_to_buffers(recording, spikes, unit_ids, wfs_arrays_info, nbefore, nafter, return_scaled, mode='shared_memory', **job_kwargs) - wfs_arrays = extract_waveforms_to_buffers( - recording, - spikes, - unit_ids, - nbefore, - nafter, - mode="shared_memory", - return_scaled=False, - folder=None, - dtype=dtype, - sparsity_mask=None, - copy=True, - **job_kwargs, - ) - for unit_ind, unit_id in enumerate(unit_ids): - wf = wfs_arrays[unit_id] - assert wf.shape[0] == np.sum(spikes["unit_index"] == unit_ind) - list_wfs.append({unit_id: wfs_arrays[unit_id].copy() for unit_id in unit_ids}) - # to avoid warning we need to first destroy arrays then sharedmemm object - # del wfs_arrays - # del wfs_arrays_info - _check_all_wf_equal(list_wfs) - - # with sparsity - wf_folder = cache_folder / "test_waveform_tools_sparse" - if wf_folder.is_dir(): - shutil.rmtree(wf_folder) - wf_folder.mkdir() - - sparsity_mask = np.random.randint(0, 2, size=(unit_ids.size, recording.channel_ids.size), dtype="bool") - job_kwargs = {"n_jobs": 1, "chunk_size": 3000, "progress_bar": True} - - wfs_arrays = extract_waveforms_to_buffers( - recording, - spikes, - unit_ids, - nbefore, - nafter, - mode="memmap", - return_scaled=False, - folder=wf_folder, - dtype=dtype, - sparsity_mask=sparsity_mask, - copy=False, - **job_kwargs, - ) + for k, mode_kwargs in enumerate(some_modes): + for l, sparsity_kwargs in enumerate(some_sparsity): + + print() + print(job_kwargs, mode_kwargs, 'sparse=', sparsity_kwargs['sparsity_mask'] is None) + + if mode_kwargs["mode"] == "memmap": + wf_folder = cache_folder / f"test_waveform_tools_{j}_{k}_{l}" + if wf_folder.is_dir(): + shutil.rmtree(wf_folder) + wf_folder.mkdir() + mode_kwargs_ = dict(**mode_kwargs, folder=wf_folder) + else: + mode_kwargs_ = mode_kwargs + + wfs_arrays = extract_waveforms_to_buffers( + recording, + spikes, + unit_ids, + nbefore, + nafter, + return_scaled=False, + dtype=dtype, + copy=True, + **sparsity_kwargs, + **mode_kwargs_, + **job_kwargs, + ) + for unit_ind, 
unit_id in enumerate(unit_ids): + wf = wfs_arrays[unit_id] + assert wf.shape[0] == np.sum(spikes["unit_index"] == unit_ind) + + if sparsity_kwargs['sparsity_mask'] is None: + list_wfs_dense.append(wfs_arrays) + else: + list_wfs_sparse.append(wfs_arrays) + + + all_waveforms = extract_waveforms_to_unique_buffer( + recording, + spikes, + unit_ids, + nbefore, + nafter, + return_scaled=False, + dtype=dtype, + copy=True, + **sparsity_kwargs, + **mode_kwargs_, + **job_kwargs, + ) + wfs_arrays = split_waveforms_by_units(unit_ids, spikes, all_waveforms, sparsity_mask=sparsity_kwargs['sparsity_mask']) + if sparsity_kwargs['sparsity_mask'] is None: + list_wfs_dense.append(wfs_arrays) + else: + list_wfs_sparse.append(wfs_arrays) + + _check_all_wf_equal(list_wfs_dense) + _check_all_wf_equal(list_wfs_sparse) + diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index a7c8493381..1a2361ec97 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -189,7 +189,7 @@ def allocate_waveforms_buffers( wfs_arrays[unit_id] = arr wfs_arrays_info[unit_id] = filename elif mode == "shared_memory": - if n_spikes == 0: + if n_spikes == 0 or num_chans == 0: arr = np.zeros(shape, dtype=dtype) shm = None shm_name = None @@ -431,15 +431,15 @@ def extract_waveforms_to_unique_buffer( if mode == "memmap": filename = str(folder / f"all_waveforms.npy") - wfs_array = np.lib.format.open_memmap(filename, mode="w+", dtype=dtype, shape=shape) + all_waveforms = np.lib.format.open_memmap(filename, mode="w+", dtype=dtype, shape=shape) wf_array_info = filename elif mode == "shared_memory": - if num_spikes == 0: - wfs_array = np.zeros(shape, dtype=dtype) + if num_spikes == 0 or num_chans == 0: + all_waveforms = np.zeros(shape, dtype=dtype) shm = None shm_name = None else: - wfs_array, shm = make_shared_array(shape, dtype) + all_waveforms, shm = make_shared_array(shape, dtype) shm_name = shm.name wf_array_info = (shm, shm_name, dtype.str, shape) else: @@ -478,17 +478,16 @@ def extract_waveforms_to_unique_buffer( if mode == "memmap": - return wfs_array + return all_waveforms elif mode == "shared_memory": if copy: - wf_array_info = wf_array_info.copy() if shm is not None: # release all sharedmem buffer # empty array have None shm.unlink() - return wfs_array + return all_waveforms.copy() else: - return wfs_array, wf_array_info + return all_waveforms, wf_array_info @@ -500,18 +499,25 @@ def _init_worker_ditribute_one_buffer( worker_ctx = {} worker_ctx["recording"] = recording worker_ctx["wf_array_info"] = wf_array_info + worker_ctx["unit_ids"] = unit_ids + worker_ctx["spikes"] = spikes + worker_ctx["nbefore"] = nbefore + worker_ctx["nafter"] = nafter + worker_ctx["return_scaled"] = return_scaled + worker_ctx["sparsity_mask"] = sparsity_mask + worker_ctx["mode"] = mode if mode == "memmap": filename = wf_array_info - wfs_array = np.load(str(filename), mmap_mode="r+") - worker_ctx["wfs_array"] = wfs_array + all_waveforms = np.load(str(filename), mmap_mode="r+") + worker_ctx["all_waveforms"] = all_waveforms elif mode == "shared_memory": from multiprocessing.shared_memory import SharedMemory shm, shm_name, dtype, shape = wf_array_info shm = SharedMemory(shm_name) - wfs_array = np.ndarray(shape=shape, dtype=dtype, buffer=shm.buf) + all_waveforms = np.ndarray(shape=shape, dtype=dtype, buffer=shm.buf) worker_ctx["shm"] = shm - worker_ctx["wfs_array"] = wfs_array + worker_ctx["all_waveforms"] = all_waveforms # prepare segment slices segment_slices = [] @@ 
-521,14 +527,6 @@ def _init_worker_ditribute_one_buffer( segment_slices.append((s0, s1)) worker_ctx["segment_slices"] = segment_slices - worker_ctx["unit_ids"] = unit_ids - worker_ctx["spikes"] = spikes - worker_ctx["nbefore"] = nbefore - worker_ctx["nafter"] = nafter - worker_ctx["return_scaled"] = return_scaled - worker_ctx["sparsity_mask"] = sparsity_mask - worker_ctx["mode"] = mode - return worker_ctx @@ -543,7 +541,7 @@ def _worker_ditribute_one_buffer(segment_index, start_frame, end_frame, worker_c nafter = worker_ctx["nafter"] return_scaled = worker_ctx["return_scaled"] sparsity_mask = worker_ctx["sparsity_mask"] - wfs_array = worker_ctx["wfs_array"] + all_waveforms = worker_ctx["all_waveforms"] seg_size = recording.get_num_samples(segment_index=segment_index) @@ -582,12 +580,25 @@ def _worker_ditribute_one_buffer(segment_index, start_frame, end_frame, worker_c wf = traces[sample_index - start - nbefore : sample_index - start + nafter, :] if sparsity_mask is None: - wfs_array[spike_ind, :, :] = None + all_waveforms[spike_ind, :, :] = wf else: mask = sparsity_mask[unit_index, :] wf = wf[:, mask] - wfs_array[spike_ind, :, :wf.shape[1]] = wf + all_waveforms[spike_ind, :, :wf.shape[1]] = wf + + +def split_waveforms_by_units(unit_ids, spikes, all_waveforms, sparsity_mask=None): + waveform_by_units = {} + for unit_index, unit_id in enumerate(unit_ids): + mask = spikes["unit_index"] == unit_index + if sparsity_mask is not None: + chan_mask = sparsity_mask[unit_index, :] + num_chans = np.sum(chan_mask) + waveform_by_units[unit_id] = all_waveforms[mask, :, :][:, :, :num_chans] + else: + waveform_by_units[unit_id] = all_waveforms[mask, :, :] + return waveform_by_units def has_exceeding_spikes(recording, sorting): From ec86e2987414dc060ff657531dd3be3b475bf49b Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 7 Jul 2023 17:02:34 +0200 Subject: [PATCH 06/35] clean --- src/spikeinterface/core/tests/test_waveform_tools.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/core/tests/test_waveform_tools.py b/src/spikeinterface/core/tests/test_waveform_tools.py index ef75180898..c86aa6d5d7 100644 --- a/src/spikeinterface/core/tests/test_waveform_tools.py +++ b/src/spikeinterface/core/tests/test_waveform_tools.py @@ -21,10 +21,6 @@ def _check_all_wf_equal(list_wfs_arrays): wfs_arrays0 = list_wfs_arrays[0] for i, wfs_arrays in enumerate(list_wfs_arrays): for unit_id in wfs_arrays.keys(): - print() - print('*'*10) - print(wfs_arrays[unit_id].shape) - print(wfs_arrays0[unit_id].shape) assert np.array_equal(wfs_arrays[unit_id], wfs_arrays0[unit_id]) @@ -79,8 +75,8 @@ def test_waveform_tools(): for k, mode_kwargs in enumerate(some_modes): for l, sparsity_kwargs in enumerate(some_sparsity): - print() - print(job_kwargs, mode_kwargs, 'sparse=', sparsity_kwargs['sparsity_mask'] is None) + # print() + # print(job_kwargs, mode_kwargs, 'sparse=', sparsity_kwargs['sparsity_mask'] is None) if mode_kwargs["mode"] == "memmap": wf_folder = cache_folder / f"test_waveform_tools_{j}_{k}_{l}" From da729088eb539991dc315c66a634ac23cc0f136f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 7 Jul 2023 15:03:10 +0000 Subject: [PATCH 07/35] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../core/tests/test_waveform_tools.py | 32 ++++++++++--------- src/spikeinterface/core/waveform_tools.py | 14 +++----- 2 files changed, 21 insertions(+), 25 deletions(-) 
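
The "shared_memory" mode that these waveform-tools patches converge on follows a standard multiprocessing pattern: the parent allocates one flat buffer, ships only the buffer name, dtype and shape to each worker, and every worker re-attaches a NumPy view onto the same memory (keeping the SharedMemory object alive while the view is in use). A minimal sketch of that pattern using only the standard library and NumPy; it does not use SpikeInterface's make_shared_array helper and all names here are illustrative:

    import numpy as np
    from multiprocessing import Pool
    from multiprocessing.shared_memory import SharedMemory

    def fill_rows(args):
        """Worker: re-attach to the shared buffer by name and write its block of rows in place."""
        shm_name, dtype, shape, row_slice = args
        shm = SharedMemory(name=shm_name)        # must stay alive while the view is used
        arr = np.ndarray(shape, dtype=dtype, buffer=shm.buf)
        arr[row_slice] = np.arange(row_slice.start, row_slice.stop)[:, None]
        del arr                                  # drop the view before closing the mapping
        shm.close()

    if __name__ == "__main__":
        shape, dtype = (8, 4), np.dtype("float32")
        shm = SharedMemory(create=True, size=int(np.prod(shape)) * dtype.itemsize)
        all_waveforms = np.ndarray(shape, dtype=dtype, buffer=shm.buf)

        tasks = [(shm.name, dtype.str, shape, slice(0, 4)),
                 (shm.name, dtype.str, shape, slice(4, 8))]
        with Pool(2) as pool:
            pool.map(fill_rows, tasks)

        result = all_waveforms.copy()            # copy out, then release the shared buffer
        del all_waveforms
        shm.close()
        shm.unlink()
        print(result[:, 0])                      # -> [0. 1. 2. 3. 4. 5. 6. 7.]
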
diff --git a/src/spikeinterface/core/tests/test_waveform_tools.py b/src/spikeinterface/core/tests/test_waveform_tools.py index c86aa6d5d7..9a51a10ee2 100644 --- a/src/spikeinterface/core/tests/test_waveform_tools.py +++ b/src/spikeinterface/core/tests/test_waveform_tools.py @@ -7,7 +7,9 @@ from spikeinterface.core import generate_recording, generate_sorting from spikeinterface.core.waveform_tools import ( - extract_waveforms_to_buffers, extract_waveforms_to_unique_buffer, split_waveforms_by_units, + extract_waveforms_to_buffers, + extract_waveforms_to_unique_buffer, + split_waveforms_by_units, ) @@ -56,17 +58,20 @@ def test_waveform_tools(): {"n_jobs": 2, "chunk_size": 3000, "progress_bar": True}, ] some_modes = [ - {"mode" : "memmap"}, + {"mode": "memmap"}, ] if platform.system() != "Windows": # shared memory on windows is buggy... - some_modes.append({"mode" : "shared_memory", }) + some_modes.append( + { + "mode": "shared_memory", + } + ) some_sparsity = [ dict(sparsity_mask=None), - dict(sparsity_mask = np.random.randint(0, 2, size=(unit_ids.size, recording.channel_ids.size), dtype="bool")), + dict(sparsity_mask=np.random.randint(0, 2, size=(unit_ids.size, recording.channel_ids.size), dtype="bool")), ] - # memmap mode list_wfs_dense = [] @@ -74,7 +79,6 @@ def test_waveform_tools(): for j, job_kwargs in enumerate(some_job_kwargs): for k, mode_kwargs in enumerate(some_modes): for l, sparsity_kwargs in enumerate(some_sparsity): - # print() # print(job_kwargs, mode_kwargs, 'sparse=', sparsity_kwargs['sparsity_mask'] is None) @@ -86,7 +90,7 @@ def test_waveform_tools(): mode_kwargs_ = dict(**mode_kwargs, folder=wf_folder) else: mode_kwargs_ = mode_kwargs - + wfs_arrays = extract_waveforms_to_buffers( recording, spikes, @@ -103,13 +107,12 @@ def test_waveform_tools(): for unit_ind, unit_id in enumerate(unit_ids): wf = wfs_arrays[unit_id] assert wf.shape[0] == np.sum(spikes["unit_index"] == unit_ind) - - if sparsity_kwargs['sparsity_mask'] is None: + + if sparsity_kwargs["sparsity_mask"] is None: list_wfs_dense.append(wfs_arrays) else: list_wfs_sparse.append(wfs_arrays) - all_waveforms = extract_waveforms_to_unique_buffer( recording, spikes, @@ -123,8 +126,10 @@ def test_waveform_tools(): **mode_kwargs_, **job_kwargs, ) - wfs_arrays = split_waveforms_by_units(unit_ids, spikes, all_waveforms, sparsity_mask=sparsity_kwargs['sparsity_mask']) - if sparsity_kwargs['sparsity_mask'] is None: + wfs_arrays = split_waveforms_by_units( + unit_ids, spikes, all_waveforms, sparsity_mask=sparsity_kwargs["sparsity_mask"] + ) + if sparsity_kwargs["sparsity_mask"] is None: list_wfs_dense.append(wfs_arrays) else: list_wfs_sparse.append(wfs_arrays) @@ -133,8 +138,5 @@ def test_waveform_tools(): _check_all_wf_equal(list_wfs_sparse) - - - if __name__ == "__main__": test_waveform_tools() diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index 1a2361ec97..e6f7e944cc 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -274,7 +274,7 @@ def distribute_waveforms_to_buffers( sparsity_mask, ) if job_name is None: - job_name=f"extract waveforms {mode} multi buffer" + job_name = f"extract waveforms {mode} multi buffer" processor = ChunkRecordingExecutor(recording, func, init_func, init_args, job_name=job_name, **job_kwargs) processor.run() @@ -413,7 +413,6 @@ def extract_waveforms_to_unique_buffer( job_name=None, **job_kwargs, ): - nsamples = nbefore + nafter dtype = np.dtype(dtype) @@ -445,7 +444,6 @@ def 
extract_waveforms_to_unique_buffer( else: raise ValueError("allocate_waveforms_buffers bad mode") - job_kwargs = fix_job_kwargs(job_kwargs) inds_by_unit = {} @@ -468,7 +466,6 @@ def extract_waveforms_to_unique_buffer( return_scaled, mode, sparsity_mask, - ) if job_name is None: job_name = f"extract waveforms {mode} mono buffer" @@ -476,7 +473,6 @@ def extract_waveforms_to_unique_buffer( processor = ChunkRecordingExecutor(recording, func, init_func, init_args, job_name=job_name, **job_kwargs) processor.run() - if mode == "memmap": return all_waveforms elif mode == "shared_memory": @@ -490,12 +486,9 @@ def extract_waveforms_to_unique_buffer( return all_waveforms, wf_array_info - - def _init_worker_ditribute_one_buffer( recording, unit_ids, spikes, wf_array_info, nbefore, nafter, return_scaled, mode, sparsity_mask ): - worker_ctx = {} worker_ctx["recording"] = recording worker_ctx["wf_array_info"] = wf_array_info @@ -513,6 +506,7 @@ def _init_worker_ditribute_one_buffer( worker_ctx["all_waveforms"] = all_waveforms elif mode == "shared_memory": from multiprocessing.shared_memory import SharedMemory + shm, shm_name, dtype, shape = wf_array_info shm = SharedMemory(shm_name) all_waveforms = np.ndarray(shape=shape, dtype=dtype, buffer=shm.buf) @@ -546,7 +540,7 @@ def _worker_ditribute_one_buffer(segment_index, start_frame, end_frame, worker_c seg_size = recording.get_num_samples(segment_index=segment_index) s0, s1 = segment_slices[segment_index] - in_seg_spikes = spikes[s0: s1] + in_seg_spikes = spikes[s0:s1] # take only spikes in range [start_frame, end_frame] # this is a slice so no copy!! @@ -584,7 +578,7 @@ def _worker_ditribute_one_buffer(segment_index, start_frame, end_frame, worker_c else: mask = sparsity_mask[unit_index, :] wf = wf[:, mask] - all_waveforms[spike_ind, :, :wf.shape[1]] = wf + all_waveforms[spike_ind, :, : wf.shape[1]] = wf def split_waveforms_by_units(unit_ids, spikes, all_waveforms, sparsity_mask=None): From 45036bf72c84376edeebb531f4feebeb8f0255e2 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 10 Jul 2023 14:30:34 +0200 Subject: [PATCH 08/35] wip waveforms tools single buffer --- .../core/tests/test_waveform_tools.py | 4 +- src/spikeinterface/core/waveform_tools.py | 173 +++++++++++++----- 2 files changed, 131 insertions(+), 46 deletions(-) diff --git a/src/spikeinterface/core/tests/test_waveform_tools.py b/src/spikeinterface/core/tests/test_waveform_tools.py index 9a51a10ee2..fb65e87458 100644 --- a/src/spikeinterface/core/tests/test_waveform_tools.py +++ b/src/spikeinterface/core/tests/test_waveform_tools.py @@ -8,7 +8,7 @@ from spikeinterface.core import generate_recording, generate_sorting from spikeinterface.core.waveform_tools import ( extract_waveforms_to_buffers, - extract_waveforms_to_unique_buffer, + extract_waveforms_to_single_buffer, split_waveforms_by_units, ) @@ -113,7 +113,7 @@ def test_waveform_tools(): else: list_wfs_sparse.append(wfs_arrays) - all_waveforms = extract_waveforms_to_unique_buffer( + all_waveforms = extract_waveforms_to_single_buffer( recording, spikes, unit_ids, diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index e6f7e944cc..252ea68738 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -36,7 +36,7 @@ def extract_waveforms_to_buffers( Same as calling allocate_waveforms_buffers() and then distribute_waveforms_to_buffers(). 
- Important note: for the "shared_memory" mode wfs_arrays_info contains reference to + Important note: for the "shared_memory" mode arrays_info contains reference to the shared memmory buffer, this variable must be reference as long as arrays as used. And this variable is also returned. To avoid this a copy to non shared memmory can be perform at the end. @@ -66,17 +66,17 @@ def extract_waveforms_to_buffers( If not None shape must be must be (len(unit_ids), len(channel_ids)) copy: bool If True (default), the output shared memory object is copied to a numpy standard array. - If copy=False then wfs_arrays_info is also return. Please keep in mind that wfs_arrays_info - need to be referenced as long as wfs_arrays will be used otherwise it will be very hard to debug. + If copy=False then arrays_info is also return. Please keep in mind that arrays_info + need to be referenced as long as waveforms_by_units will be used otherwise it will be very hard to debug. Also when copy=False the SharedMemory will need to be unlink manually {} Returns ------- - wfs_arrays: dict of arrays + waveforms_by_units: dict of arrays Arrays for all units (memmap or shared_memmep) - wfs_arrays_info: dict of info + arrays_info: dict of info Optionally return in case of shared_memory if copy=False. Dictionary to "construct" array in workers process (memmap file or sharemem info) """ @@ -89,7 +89,7 @@ def extract_waveforms_to_buffers( dtype = "float32" dtype = np.dtype(dtype) - wfs_arrays, wfs_arrays_info = allocate_waveforms_buffers( + waveforms_by_units, arrays_info = allocate_waveforms_buffers( recording, spikes, unit_ids, nbefore, nafter, mode=mode, folder=folder, dtype=dtype, sparsity_mask=sparsity_mask ) @@ -97,7 +97,7 @@ def extract_waveforms_to_buffers( recording, spikes, unit_ids, - wfs_arrays_info, + arrays_info, nbefore, nafter, return_scaled, @@ -107,19 +107,19 @@ def extract_waveforms_to_buffers( ) if mode == "memmap": - return wfs_arrays + return waveforms_by_units elif mode == "shared_memory": if copy: - wfs_arrays = {unit_id: arr.copy() for unit_id, arr in wfs_arrays.items()} + waveforms_by_units = {unit_id: arr.copy() for unit_id, arr in waveforms_by_units.items()} # release all sharedmem buffer for unit_id in unit_ids: - shm = wfs_arrays_info[unit_id][0] + shm = arrays_info[unit_id][0] if shm is not None: # empty array have None shm.unlink() - return wfs_arrays + return waveforms_by_units else: - return wfs_arrays, wfs_arrays_info + return waveforms_by_units, arrays_info extract_waveforms_to_buffers.__doc__ = extract_waveforms_to_buffers.__doc__.format(_shared_job_kwargs_doc) @@ -131,7 +131,7 @@ def allocate_waveforms_buffers( """ Allocate memmap or shared memory buffers before snippet extraction. - Important note: for the shared memory mode wfs_arrays_info contains reference to + Important note: for the shared memory mode arrays_info contains reference to the shared memmory buffer, this variable must be reference as long as arrays as used. 
Parameters @@ -158,9 +158,9 @@ def allocate_waveforms_buffers( Returns ------- - wfs_arrays: dict of arrays + waveforms_by_units: dict of arrays Arrays for all units (memmap or shared_memmep - wfs_arrays_info: dict of info + arrays_info: dict of info Dictionary to "construct" array in workers process (memmap file or sharemem) """ @@ -173,8 +173,8 @@ def allocate_waveforms_buffers( folder = Path(folder) # prepare buffers - wfs_arrays = {} - wfs_arrays_info = {} + waveforms_by_units = {} + arrays_info = {} for unit_ind, unit_id in enumerate(unit_ids): n_spikes = np.sum(spikes["unit_index"] == unit_ind) if sparsity_mask is None: @@ -186,8 +186,8 @@ def allocate_waveforms_buffers( if mode == "memmap": filename = str(folder / f"waveforms_{unit_id}.npy") arr = np.lib.format.open_memmap(filename, mode="w+", dtype=dtype, shape=shape) - wfs_arrays[unit_id] = arr - wfs_arrays_info[unit_id] = filename + waveforms_by_units[unit_id] = arr + arrays_info[unit_id] = filename elif mode == "shared_memory": if n_spikes == 0 or num_chans == 0: arr = np.zeros(shape, dtype=dtype) @@ -196,19 +196,19 @@ def allocate_waveforms_buffers( else: arr, shm = make_shared_array(shape, dtype) shm_name = shm.name - wfs_arrays[unit_id] = arr - wfs_arrays_info[unit_id] = (shm, shm_name, dtype.str, shape) + waveforms_by_units[unit_id] = arr + arrays_info[unit_id] = (shm, shm_name, dtype.str, shape) else: raise ValueError("allocate_waveforms_buffers bad mode") - return wfs_arrays, wfs_arrays_info + return waveforms_by_units, arrays_info def distribute_waveforms_to_buffers( recording, spikes, unit_ids, - wfs_arrays_info, + arrays_info, nbefore, nafter, return_scaled, @@ -222,7 +222,7 @@ def distribute_waveforms_to_buffers( Buffers must be pre-allocated with the `allocate_waveforms_buffers()` function. - Important note, for "shared_memory" mode wfs_arrays_info contain reference to + Important note, for "shared_memory" mode arrays_info contain reference to the shared memmory buffer, this variable must be reference as long as arrays as used. 
Parameters @@ -234,7 +234,7 @@ def distribute_waveforms_to_buffers( This vector can be spikes = Sorting.to_spike_vector() unit_ids: list ot numpy List of unit_ids - wfs_arrays_info: dict + arrays_info: dict Dictionary to "construct" array in workers process (memmap file or sharemem) nbefore: int N samples before spike @@ -265,7 +265,7 @@ def distribute_waveforms_to_buffers( recording, unit_ids, spikes, - wfs_arrays_info, + arrays_info, nbefore, nafter, return_scaled, @@ -284,7 +284,7 @@ def distribute_waveforms_to_buffers( # used by ChunkRecordingExecutor def _init_worker_ditribute_buffers( - recording, unit_ids, spikes, wfs_arrays_info, nbefore, nafter, return_scaled, inds_by_unit, mode, sparsity_mask + recording, unit_ids, spikes, arrays_info, nbefore, nafter, return_scaled, inds_by_unit, mode, sparsity_mask ): # create a local dict per worker worker_ctx = {} @@ -297,23 +297,23 @@ def _init_worker_ditribute_buffers( if mode == "memmap": # in memmap mode we have the "too many open file" problem with linux # memmap file will be open on demand and not globally per worker - worker_ctx["wfs_arrays_info"] = wfs_arrays_info + worker_ctx["arrays_info"] = arrays_info elif mode == "shared_memory": from multiprocessing.shared_memory import SharedMemory - wfs_arrays = {} + waveforms_by_units = {} shms = {} - for unit_id, (shm, shm_name, dtype, shape) in wfs_arrays_info.items(): + for unit_id, (shm, shm_name, dtype, shape) in arrays_info.items(): if shm_name is None: arr = np.zeros(shape=shape, dtype=dtype) else: shm = SharedMemory(shm_name) arr = np.ndarray(shape=shape, dtype=dtype, buffer=shm.buf) - wfs_arrays[unit_id] = arr + waveforms_by_units[unit_id] = arr # we need a reference to all sham otherwise we get segment fault!!! shms[unit_id] = shm worker_ctx["shms"] = shms - worker_ctx["wfs_arrays"] = wfs_arrays + worker_ctx["waveforms_by_units"] = waveforms_by_units worker_ctx["unit_ids"] = unit_ids worker_ctx["spikes"] = spikes @@ -383,10 +383,10 @@ def _worker_ditribute_buffers(segment_index, start_frame, end_frame, worker_ctx) if worker_ctx["mode"] == "memmap": # open file in demand (and also autoclose it after) - filename = worker_ctx["wfs_arrays_info"][unit_id] + filename = worker_ctx["arrays_info"][unit_id] wfs = np.load(str(filename), mmap_mode="r+") elif worker_ctx["mode"] == "shared_memory": - wfs = worker_ctx["wfs_arrays"][unit_id] + wfs = worker_ctx["waveforms_by_units"][unit_id] for pos in in_chunk_pos: sample_index = spikes[inds[pos]]["sample_index"] @@ -398,7 +398,7 @@ def _worker_ditribute_buffers(segment_index, start_frame, end_frame, worker_ctx) wfs[pos, :, :] = wf[:, sparsity_mask[unit_ind]] -def extract_waveforms_to_unique_buffer( +def extract_waveforms_to_single_buffer( recording, spikes, unit_ids, @@ -413,6 +413,57 @@ def extract_waveforms_to_unique_buffer( job_name=None, **job_kwargs, ): + """ + Allocate a single buffer (memmap or or shared memory) and then distribute every waveform into it. + + Contrary to extract_waveforms_to_buffers() all waveforms are extracted in the same buffer, so the spike vector is + needed to recover waveforms unit by unit. Importantly in case of sparsity, the channel are not aligned across + units. + + Important note: for the "shared_memory" mode wf_array_info contains reference to + the shared memmory buffer, this variable must be reference as long as arrays as used. + And this variable is also returned. + To avoid this a copy to non shared memmory can be perform at the end. 
+ + Parameters + ---------- + recording: recording + The recording object + spikes: 1d numpy array with several fields + Spikes handled as a unique vector. + This vector can be obtained with: `spikes = Sorting.to_spike_vector()` + unit_ids: list ot numpy + List of unit_ids + nbefore: int + N samples before spike + nafter: int + N samples after spike + mode: str + Mode to use ('memmap' | 'shared_memory') + return_scaled: bool + Scale traces before exporting to buffer or not. + folder: str or path + In case of memmap mode, folder to save npy files + dtype: numpy.dtype + dtype for waveforms buffer + sparsity_mask: None or array of bool + If not None shape must be must be (len(unit_ids), len(channel_ids)) + copy: bool + If True (default), the output shared memory object is copied to a numpy standard array. + If copy=False then arrays_info is also return. Please keep in mind that arrays_info + need to be referenced as long as waveforms_by_units will be used otherwise it will be very hard to debug. + Also when copy=False the SharedMemory will need to be unlink manually + {} + + Returns + ------- + all_waveforms: numpy array + Single array with shape (nump_spikes, num_samples, num_channels) + + wf_array_info: dict of info + Optionally return in case of shared_memory if copy=False. + Dictionary to "construct" array in workers process (memmap file or sharemem info) + """ nsamples = nbefore + nafter dtype = np.dtype(dtype) @@ -453,8 +504,8 @@ def extract_waveforms_to_unique_buffer( if num_spikes > 0: # and run - func = _worker_ditribute_one_buffer - init_func = _init_worker_ditribute_one_buffer + func = _worker_ditribute_single_buffer + init_func = _init_worker_ditribute_single_buffer init_args = ( recording, @@ -486,7 +537,7 @@ def extract_waveforms_to_unique_buffer( return all_waveforms, wf_array_info -def _init_worker_ditribute_one_buffer( +def _init_worker_ditribute_single_buffer( recording, unit_ids, spikes, wf_array_info, nbefore, nafter, return_scaled, mode, sparsity_mask ): worker_ctx = {} @@ -525,7 +576,7 @@ def _init_worker_ditribute_one_buffer( # used by ChunkRecordingExecutor -def _worker_ditribute_one_buffer(segment_index, start_frame, end_frame, worker_ctx): +def _worker_ditribute_single_buffer(segment_index, start_frame, end_frame, worker_ctx): # recover variables of the worker recording = worker_ctx["recording"] unit_ids = worker_ctx["unit_ids"] @@ -580,19 +631,53 @@ def _worker_ditribute_one_buffer(segment_index, start_frame, end_frame, worker_c wf = wf[:, mask] all_waveforms[spike_ind, :, : wf.shape[1]] = wf + if worker_ctx["mode"] == "memmap": + all_waveforms.flush() + + +def split_waveforms_by_units(unit_ids, spikes, all_waveforms, sparsity_mask=None, folder=None): + """ + Split a single buffer waveforms into waveforms by units (multi buffers or multi files). + + Parameters + ---------- + unit_ids: list or numpy array + List of unit ids + spikes: numpy array + The spike vector + all_waveforms : numpy array + Single buffer containing all waveforms + sparsity_mask : None or numpy array + Optionally the boolean sparsity mask + folder : None or str or Path + If a folde ri sgiven all -def split_waveforms_by_units(unit_ids, spikes, all_waveforms, sparsity_mask=None): - waveform_by_units = {} + Returns + ------- + waveforms_by_units: dict of array + A dict of arrays. + In case of folder not None, this contain the memmap of the files. 
+ """ + if folder is not None: + folder = Path(folder) + waveforms_by_units = {} for unit_index, unit_id in enumerate(unit_ids): mask = spikes["unit_index"] == unit_index if sparsity_mask is not None: chan_mask = sparsity_mask[unit_index, :] num_chans = np.sum(chan_mask) - waveform_by_units[unit_id] = all_waveforms[mask, :, :][:, :, :num_chans] + wfs = all_waveforms[mask, :, :][:, :, :num_chans] + else: + wfs = all_waveforms[mask, :, :] + + if folder is None: + waveforms_by_units[unit_id] = wfs else: - waveform_by_units[unit_id] = all_waveforms[mask, :, :] + np.save(folder / f"waveforms_{unit_id}.npy", wfs) + # this avoid keeping in memory all waveforms + waveforms_by_units[unit_id] = np.load(f"waveforms_{unit_id}.npy", mmap_mode="r") - return waveform_by_units + return waveforms_by_units def has_exceeding_spikes(recording, sorting): From 570a3a5fa8dc1b340da4ee0fc6baa16013213bdb Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 26 Jul 2023 17:57:40 +0200 Subject: [PATCH 09/35] fix local tests --- src/spikeinterface/core/tests/test_waveform_tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/tests/test_waveform_tools.py b/src/spikeinterface/core/tests/test_waveform_tools.py index fb65e87458..52d7472c92 100644 --- a/src/spikeinterface/core/tests/test_waveform_tools.py +++ b/src/spikeinterface/core/tests/test_waveform_tools.py @@ -86,7 +86,7 @@ def test_waveform_tools(): wf_folder = cache_folder / f"test_waveform_tools_{j}_{k}_{l}" if wf_folder.is_dir(): shutil.rmtree(wf_folder) - wf_folder.mkdir() + wf_folder.mkdir(parents=True) mode_kwargs_ = dict(**mode_kwargs, folder=wf_folder) else: mode_kwargs_ = mode_kwargs From f737f5a4b4b998321ff2999ec57915f1c1b6dc82 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 24 Aug 2023 15:14:55 +0200 Subject: [PATCH 10/35] wip collisions --- .../postprocessing/amplitude_scalings.py | 384 +++++++++++++++++- 1 file changed, 370 insertions(+), 14 deletions(-) diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 3ebeafcfec..7539e4d0b7 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -22,8 +22,25 @@ def __init__(self, waveform_extractor): extremum_channel_inds=extremum_channel_inds, use_cache=False ) - def _set_params(self, sparsity, max_dense_channels, ms_before, ms_after): - params = dict(sparsity=sparsity, max_dense_channels=max_dense_channels, ms_before=ms_before, ms_after=ms_after) + def _set_params( + self, + sparsity, + max_dense_channels, + ms_before, + ms_after, + handle_collisions, + max_consecutive_collisions, + delta_collision_ms, + ): + params = dict( + sparsity=sparsity, + max_dense_channels=max_dense_channels, + ms_before=ms_before, + ms_after=ms_after, + handle_collisions=handle_collisions, + max_consecutive_collisions=max_consecutive_collisions, + delta_collision_ms=delta_collision_ms, + ) return params def _select_extension_data(self, unit_ids): @@ -43,6 +60,12 @@ def _run(self, **job_kwargs): ms_before = self._params["ms_before"] ms_after = self._params["ms_after"] + # collisions + handle_collisions = self._params["handle_collisions"] + max_consecutive_collisions = self._params["max_consecutive_collisions"] + delta_collision_ms = self._params["delta_collision_ms"] + delta_collision_samples = int(delta_collision_ms / 1000 * we.sampling_frequency) + return_scaled = we._params["return_scaled"] unit_ids = we.unit_ids @@ -67,6 
+90,8 @@ def _run(self, **job_kwargs): assert recording.get_num_channels() <= self._params["max_dense_channels"], "" sparsity = ChannelSparsity.create_dense(we) sparsity_inds = sparsity.unit_id_to_channel_indices + + # easier to use in chunk function as spikes use unit_index instead o id unit_inds_to_channel_indices = {unit_ind: sparsity_inds[unit_id] for unit_ind, unit_id in enumerate(unit_ids)} all_templates = we.get_all_templates() @@ -93,6 +118,9 @@ def _run(self, **job_kwargs): cut_out_before, cut_out_after, return_scaled, + handle_collisions, + max_consecutive_collisions, + delta_collision_samples, ) processor = ChunkRecordingExecutor( recording, @@ -154,6 +182,9 @@ def compute_amplitude_scalings( max_dense_channels=16, ms_before=None, ms_after=None, + handle_collisions=False, + max_consecutive_collisions=3, + delta_collision_ms=2, load_if_exists=False, outputs="concatenated", **job_kwargs, @@ -165,22 +196,29 @@ def compute_amplitude_scalings( ---------- waveform_extractor: WaveformExtractor The waveform extractor object - sparsity: ChannelSparsity + sparsity: ChannelSparsity, default: None If waveforms are not sparse, sparsity is required if the number of channels is greater than `max_dense_channels`. If the waveform extractor is sparse, its sparsity is automatically used. - By default None max_dense_channels: int, default: 16 Maximum number of channels to allow running without sparsity. To compute amplitude scaling using dense waveforms, set this to None, sparsity to None, and pass dense waveforms as input. - ms_before : float, optional + ms_before : float, default: None The cut out to apply before the spike peak to extract local waveforms. - If None, the WaveformExtractor ms_before is used, by default None - ms_after : float, optional + If None, the WaveformExtractor ms_before is used. + ms_after : float, default: None The cut out to apply after the spike peak to extract local waveforms. - If None, the WaveformExtractor ms_after is used, by default None + If None, the WaveformExtractor ms_after is used. + handle_collisions: bool, default: False + Whether to handle collisions between spikes. If True, the amplitude scaling of colliding spikes + (defined as spikes within `delta_collision_ms` ms and with overlapping sparsity) is computed by fitting a + multi-linear regression model (with `sklearn.LinearRegression`). If False, each spike is fitted independently. + max_consecutive_collisions: int, default: 3 + The maximum number of consecutive collisions to handle on each side of a spike. + delta_collision_ms: float, default: 2 + The maximum time difference in ms between two spikes to be considered as colliding. load_if_exists : bool, default: False Whether to load precomputed spike amplitudes, if they already exist. 
- outputs: str + outputs: str, default: 'concatenated' How the output should be returned: - 'concatenated' - 'by_unit' @@ -197,7 +235,15 @@ def compute_amplitude_scalings( sac = waveform_extractor.load_extension(AmplitudeScalingsCalculator.extension_name) else: sac = AmplitudeScalingsCalculator(waveform_extractor) - sac.set_params(sparsity=sparsity, max_dense_channels=max_dense_channels, ms_before=ms_before, ms_after=ms_after) + sac.set_params( + sparsity=sparsity, + max_dense_channels=max_dense_channels, + ms_before=ms_before, + ms_after=ms_after, + handle_collisions=handle_collisions, + max_consecutive_collisions=max_consecutive_collisions, + delta_collision_ms=delta_collision_ms, + ) sac.run(**job_kwargs) amps = sac.get_data(outputs=outputs) @@ -218,6 +264,9 @@ def _init_worker_amplitude_scalings( cut_out_before, cut_out_after, return_scaled, + handle_collisions, + max_consecutive_collisions, + delta_collision_samples, ): # create a local dict per worker worker_ctx = {} @@ -229,9 +278,18 @@ def _init_worker_amplitude_scalings( worker_ctx["nafter"] = nafter worker_ctx["cut_out_before"] = cut_out_before worker_ctx["cut_out_after"] = cut_out_after - worker_ctx["margin"] = max(nbefore, nafter) worker_ctx["return_scaled"] = return_scaled worker_ctx["unit_inds_to_channel_indices"] = unit_inds_to_channel_indices + worker_ctx["handle_collisions"] = handle_collisions + worker_ctx["max_consecutive_collisions"] = max_consecutive_collisions + worker_ctx["delta_collision_samples"] = delta_collision_samples + + if not handle_collisions: + worker_ctx["margin"] = max(nbefore, nafter) + else: + margin_waveforms = max(nbefore, nafter) + max_margin_collisions = int(max_consecutive_collisions * delta_collision_samples) + worker_ctx["margin"] = max(margin_waveforms, max_margin_collisions) return worker_ctx @@ -250,6 +308,9 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) cut_out_after = worker_ctx["cut_out_after"] margin = worker_ctx["margin"] return_scaled = worker_ctx["return_scaled"] + handle_collisions = worker_ctx["handle_collisions"] + max_consecutive_collisions = worker_ctx["max_consecutive_collisions"] + delta_collision_samples = worker_ctx["delta_collision_samples"] spikes_in_segment = spikes[segment_slices[segment_index]] @@ -272,8 +333,24 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) offsets = recording.get_property("offset_to_uV") traces_with_margin = traces_with_margin.astype("float32") * gains + offsets + # set colliding spikes apart (if needed) + if handle_collisions: + overlapping_mask = _find_overlapping_mask( + local_spikes, max_consecutive_collisions, delta_collision_samples, unit_inds_to_channel_indices + ) + overlapping_spike_indices = overlapping_mask[:, max_consecutive_collisions] + print( + f"Found {len(overlapping_spike_indices)} overlapping spikes in segment {segment_index}! 
- chunk {start_frame} - {end_frame}" + ) + else: + overlapping_spike_indices = np.array([], dtype=int) + # get all waveforms - for spike in local_spikes: + scalings = np.zeros(len(local_spikes), dtype=float) + for spike_index, spike in enumerate(local_spikes): + if spike_index in overlapping_spike_indices: + # we deal with overlapping spikes later + continue unit_index = spike["unit_index"] sample_index = spike["sample_index"] sparse_indices = unit_inds_to_channel_indices[unit_index] @@ -294,7 +371,286 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) local_waveforms.append(local_waveform) templates.append(template) linregress_res = linregress(template.flatten(), local_waveform.flatten()) - scalings.append(linregress_res[0]) - scalings = np.array(scalings) + scalings[spike_index] = linregress_res[0] + + # deal with collisions + if len(overlapping_spike_indices) > 0: + for overlapping in overlapping_mask: + spike_index = overlapping[max_consecutive_collisions] + overlapping_spikes = local_spikes[overlapping[overlapping >= 0]] + scaled_amps = _fit_collision( + overlapping_spikes, + traces_with_margin, + start_frame, + end_frame, + left, + right, + nbefore, + all_templates, + unit_inds_to_channel_indices, + cut_out_before, + cut_out_after, + ) + # get the right amplitude scaling + scalings[spike_index] = scaled_amps[np.where(overlapping >= 0)[0] == max_consecutive_collisions] return (scalings,) + + +### Collision handling ### +def _are_unit_indices_overlapping(unit_inds_to_channel_indices, i, j): + """ + Returns True if the unit indices i and j are overlapping, False otherwise + + Parameters + ---------- + unit_inds_to_channel_indices: dict + A dictionary mapping unit indices to channel indices + i: int + The first unit index + j: int + The second unit index + + Returns + ------- + bool + True if the unit indices i and j are overlapping, False otherwise + """ + if len(np.intersect1d(unit_inds_to_channel_indices[i], unit_inds_to_channel_indices[j])) > 0: + return True + else: + return False + + +def _find_overlapping_mask(spikes, max_consecutive_spikes, delta_overlap_samples, unit_inds_to_channel_indices): + """ + Finds the overlapping spikes for each spike in spikes and returns a boolean mask of shape + (n_spikes, 2 * max_consecutive_spikes + 1). + + Parameters + ---------- + spikes: np.array + An array of spikes + max_consecutive_spikes: int + The maximum number of consecutive spikes to consider + delta_overlap_samples: int + The maximum number of samples between two spikes to consider them as overlapping + unit_inds_to_channel_indices: dict + A dictionary mapping unit indices to channel indices + + Returns + ------- + overlapping_mask: np.array + A boolean mask of shape (n_spikes, 2 * max_consecutive_spikes + 1) where the central column (max_consecutive_spikes) + is the current spike index, while the other columns are the indices of the overlapping spikes. The first + max_consecutive_spikes columns are the pre-overlapping spikes, while the last max_consecutive_spikes columns are + the post-overlapping spikes. + """ + + # overlapping_mask is a matrix of shape (n_spikes, 2 * max_consecutive_spikes + 1) + # the central column (max_consecutive_spikes) is the current spike index, while the other columns are the + # indices of the overlapping spikes. 
The first max_consecutive_spikes columns are the pre-overlapping spikes, + # while the last max_consecutive_spikes columns are the post-overlapping spikes + # Rows with all -1 are non-colliding spikes and are removed later + overlapping_mask_full = -1 * np.ones((len(spikes), 2 * max_consecutive_spikes + 1), dtype=int) + overlapping_mask_full[:, max_consecutive_spikes] = np.arange(len(spikes)) + + for i, spike in enumerate(spikes): + # find the possible spikes per and post within max_consecutive_spikes * delta_overlap_samples + consecutive_window_pre = np.searchsorted( + spikes["sample_index"], + spike["sample_index"] - max_consecutive_spikes * delta_overlap_samples, + ) + consecutive_window_post = np.searchsorted( + spikes["sample_index"], + spike["sample_index"] + max_consecutive_spikes * delta_overlap_samples, + ) + pre_possible_consecutive_spikes = spikes[consecutive_window_pre:i][::-1] + post_possible_consecutive_spikes = spikes[i + 1 : consecutive_window_post] + + # here we fill in the overlapping information by consecutively looping through the possible consecutive spikes + # and checking the spatial overlap and the delay with the previous overlapping spike + # pre and post are hanlded separately. Note that the pre-spikes are already sorted backwards + + # overlap_rank keeps track of the rank of consecutive collisions (i.e., rank 0 is the first, rank 1 is the second, etc.) + # this is needed because we are just considering spikes with spatial overlap, while the possible consecutive spikes + # only looked at the temporal overlap + overlap_rank = 0 + if len(pre_possible_consecutive_spikes) > 0: + for c_pre, spike_consecutive_pre in enumerate(pre_possible_consecutive_spikes[::-1]): + if _are_unit_indices_overlapping( + unit_inds_to_channel_indices, spike["unit_index"], spike_consecutive_pre["unit_index"] + ): + if ( + spikes[overlapping_mask_full[i, max_consecutive_spikes - overlap_rank]]["sample_index"] + - spike_consecutive_pre["sample_index"] + < delta_overlap_samples + ): + overlapping_mask_full[i, max_consecutive_spikes - overlap_rank - 1] = i - 1 - c_pre + overlap_rank += 1 + else: + break + # if overlap_rank > 1: + # print(f"\tHigher order pre-overlap for spike {i}!") + + overlap_rank = 0 + if len(post_possible_consecutive_spikes) > 0: + for c_post, spike_consecutive_post in enumerate(post_possible_consecutive_spikes): + if _are_unit_indices_overlapping( + unit_inds_to_channel_indices, spike["unit_index"], spike_consecutive_post["unit_index"] + ): + if ( + spike_consecutive_post["sample_index"] + - spikes[overlapping_mask_full[i, max_consecutive_spikes + overlap_rank]]["sample_index"] + < delta_overlap_samples + ): + overlapping_mask_full[i, max_consecutive_spikes + overlap_rank + 1] = i + 1 + c_post + overlap_rank += 1 + else: + break + # if overlap_rank > 1: + # print(f"\tHigher order post-overlap for spike {i}!") + + # in case no collisions were found, we set the central column to -1 so that we can easily identify the non-colliding spikes + if np.sum(overlapping_mask_full[i] != -1) == 1: + overlapping_mask_full[i, max_consecutive_spikes] = -1 + + # only return rows with collisions + overlapping_inds = [] + for i, overlapping in enumerate(overlapping_mask_full): + if np.any(overlapping >= 0): + overlapping_inds.append(i) + overlapping_mask = overlapping_mask_full[overlapping_inds] + + return overlapping_mask + + +def _fit_collision( + overlapping_spikes, + traces_with_margin, + start_frame, + end_frame, + left, + right, + nbefore, + all_templates, + 
unit_inds_to_channel_indices, + cut_out_before, + cut_out_after, +): + """ """ + from sklearn.linear_model import LinearRegression + + sample_first_centered = overlapping_spikes[0]["sample_index"] - start_frame - left + sample_last_centered = overlapping_spikes[-1]["sample_index"] - start_frame - left + + # construct sparsity as union between units' sparsity + sparse_indices = np.array([], dtype="int") + for spike in overlapping_spikes: + sparse_indices_i = unit_inds_to_channel_indices[spike["unit_index"]] + sparse_indices = np.union1d(sparse_indices, sparse_indices_i) + + local_waveform_start = max(0, sample_first_centered - cut_out_before) + local_waveform_end = min(traces_with_margin.shape[0], sample_last_centered + cut_out_after) + local_waveform = traces_with_margin[local_waveform_start:local_waveform_end, sparse_indices] + + y = local_waveform.T.flatten() + X = np.zeros((len(y), len(overlapping_spikes))) + for i, spike in enumerate(overlapping_spikes): + full_template = np.zeros_like(local_waveform) + # center wrt cutout traces + sample_centered = spike["sample_index"] - local_waveform_start + template = all_templates[spike["unit_index"]][:, sparse_indices] + template_cut = template[nbefore - cut_out_before : nbefore + cut_out_after] + # deal with borders + if sample_centered - cut_out_before < 0: + full_template[: sample_centered + cut_out_after] = template_cut[cut_out_before - sample_centered :] + elif sample_centered + cut_out_after > end_frame + right: + full_template[sample_centered - cut_out_before :] = template_cut[: -cut_out_after - (end_frame + right)] + else: + full_template[sample_centered - cut_out_before : sample_centered + cut_out_after] = template_cut + X[:, i] = full_template.T.flatten() + + reg = LinearRegression().fit(X, y) + amps = reg.coef_ + return amps + + +# TODO: fix this! 
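An aside on what _fit_collision above computes: the stretch of traces spanning the colliding spikes is modelled as a linear combination of the units' templates, each shifted to its spike time and restricted to the union of the units' sparse channels, and the regression coefficients are the per-spike amplitude scalings. A self-contained toy version of that least-squares fit is sketched below; all names are invented for the illustration, and np.linalg.lstsq returns the same coefficients that LinearRegression recovers here (up to the intercept the latter also fits).

    import numpy as np

    num_samples, num_channels = 60, 4
    rng = np.random.default_rng(0)
    template_a = rng.normal(size=(num_samples, num_channels))
    template_b = rng.normal(size=(num_samples, num_channels))

    # two colliding spikes, 10 samples apart, with true scalings 0.8 and 1.3
    traces = np.zeros((num_samples + 10, num_channels))
    traces[:num_samples] += 0.8 * template_a
    traces[10:] += 1.3 * template_b

    # design matrix: one column per colliding spike, holding its time-shifted template
    shifted_a = np.zeros_like(traces)
    shifted_a[:num_samples] = template_a
    shifted_b = np.zeros_like(traces)
    shifted_b[10:] = template_b
    X = np.column_stack([shifted_a.T.flatten(), shifted_b.T.flatten()])
    y = traces.T.flatten()

    scalings, *_ = np.linalg.lstsq(X, y, rcond=None)
    # scalings recovers [0.8, 1.3] exactly, since the toy traces are noise-free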
+# def plot_overlapping_spikes(we, overlap, +# spikes, cut_out_samples=100, +# max_consecutive_spikes=3, +# sparsity=None, +# fitted_amps=None): +# recording = we.recording +# nbefore_nafter_max = max(we.nafter, we.nbefore) +# cut_out_samples = max(cut_out_samples, nbefore_nafter_max) +# spike_index = overlap[max_consecutive_spikes] +# overlap_indices = overlap[overlap != -1] +# overlapping_spikes = spikes[overlap_indices] + +# if sparsity is not None: +# unit_inds_to_channel_indices = sparsity.unit_id_to_channel_indices +# sparse_indices = np.array([], dtype="int") +# for spike in overlapping_spikes: +# sparse_indices_i = unit_inds_to_channel_indices[we.unit_ids[spike["unit_index"]]] +# sparse_indices = np.union1d(sparse_indices, sparse_indices_i) +# else: +# sparse_indices = np.unique(overlapping_spikes["channel_index"]) + +# channel_ids = recording.channel_ids[sparse_indices] + +# center_spike = spikes[spike_index]["sample_index"] +# max_delta = np.max([np.abs(center_spike - overlapping_spikes[0]["sample_index"]), +# np.abs(center_spike - overlapping_spikes[-1]["sample_index"])]) +# sf = center_spike - max_delta - cut_out_samples +# ef = center_spike + max_delta + cut_out_samples +# tr_overlap = recording.get_traces(start_frame=sf, +# end_frame=ef, +# channel_ids=channel_ids, return_scaled=True) +# ts = np.arange(sf, ef) / recording.sampling_frequency * 1000 +# max_tr = np.max(np.abs(tr_overlap)) +# fig, ax = plt.subplots() +# for ch, tr in enumerate(tr_overlap.T): +# _ = ax.plot(ts, tr + 1.2 * ch * max_tr, color="k") +# ax.text(ts[0], 1.2 * ch * max_tr - 0.3 * max_tr, f"Ch:{channel_ids[ch]}") + +# used_labels = [] +# for spike in overlapping_spikes: +# label = f"U{spike['unit_index']}" +# if label in used_labels: +# label = None +# else: +# used_labels.append(label) +# ax.axvline(spike["sample_index"] / recording.sampling_frequency * 1000, +# color=f"C{spike['unit_index']}", label=label) + +# if fitted_amps is not None: +# fitted_traces = np.zeros_like(tr_overlap) + +# all_templates = we.get_all_templates() +# for i, spike in enumerate(overlapping_spikes): +# template = all_templates[spike["unit_index"]] +# template_scaled = fitted_amps[overlap_indices[i]] * template +# template_scaled_sparse = template_scaled[:, sparse_indices] +# sample_start = spike["sample_index"] - we.nbefore +# sample_end = sample_start + template_scaled_sparse.shape[0] + +# fitted_traces[sample_start - sf: sample_end - sf] += template_scaled_sparse + +# for ch, temp in enumerate(template_scaled_sparse.T): + +# ts_template = np.arange(sample_start, sample_end) / recording.sampling_frequency * 1000 +# _ = ax.plot(ts_template, temp + 1.2 * ch * max_tr, color=f"C{spike['unit_index']}", +# ls="--") + +# for ch, tr in enumerate(fitted_traces.T): +# _ = ax.plot(ts, tr + 1.2 * ch * max_tr, color="gray", alpha=0.7) + +# fitted_line = ax.get_lines()[-1] +# fitted_line.set_label("Fitted") + + +# ax.legend() +# ax.set_title(f"Spike {spike_index} - sample {center_spike}") +# return tr_overlap, ax From b9391a69c26f027e40da7cf0c3b7cffbf68b2d5e Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 28 Aug 2023 09:55:29 +0200 Subject: [PATCH 11/35] wip collisions --- .../postprocessing/amplitude_scalings.py | 301 ++++++++++++------ 1 file changed, 203 insertions(+), 98 deletions(-) diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 7539e4d0b7..d367ef4f22 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ 
b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -21,6 +21,7 @@ def __init__(self, waveform_extractor): self.spikes = self.waveform_extractor.sorting.to_spike_vector( extremum_channel_inds=extremum_channel_inds, use_cache=False ) + self.overlapping_mask = None def _set_params( self, @@ -132,8 +133,30 @@ def _run(self, **job_kwargs): **job_kwargs, ) out = processor.run() - (amp_scalings,) = zip(*out) + (amp_scalings, overlapping_mask) = zip(*out) amp_scalings = np.concatenate(amp_scalings) + if handle_collisions > 0: + from ..core.job_tools import divide_recording_into_chunks + + overlapping_mask_corrected = [] + all_chunks = divide_recording_into_chunks(processor.recording, processor.chunk_size) + num_spikes_so_far = 0 + for i, overlapping in enumerate(overlapping_mask): + if i == 0: + continue + segment_index = all_chunks[i - 1][0] + spikes_in_segment = self.spikes[segment_slices[segment_index]] + i0 = np.searchsorted(spikes_in_segment["sample_index"], all_chunks[i - 1][1]) + i1 = np.searchsorted(spikes_in_segment["sample_index"], all_chunks[i - 1][2]) + num_spikes_so_far += i1 - i0 + overlapping_corrected = overlapping.copy() + overlapping_corrected[overlapping_corrected >= 0] += num_spikes_so_far + overlapping_mask_corrected.append(overlapping_corrected) + overlapping_mask = np.concatenate(overlapping_mask_corrected) + print(f"Found {len(overlapping_mask)} overlapping spikes") + self.overlapping_mask = overlapping_mask + else: + overlapping_mask = np.concatenate(overlapping_mask) self._extension_data[f"amplitude_scalings"] = amp_scalings @@ -314,13 +337,10 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) spikes_in_segment = spikes[segment_slices[segment_index]] + # TODO: handle spikes in margin! i0 = np.searchsorted(spikes_in_segment["sample_index"], start_frame) i1 = np.searchsorted(spikes_in_segment["sample_index"], end_frame) - local_waveforms = [] - templates = [] - scalings = [] - if i0 != i1: local_spikes = spikes_in_segment[i0:i1] traces_with_margin, left, right = get_chunk_with_margin( @@ -335,13 +355,10 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) # set colliding spikes apart (if needed) if handle_collisions: - overlapping_mask = _find_overlapping_mask( + overlapping_mask = find_overlapping_mask( local_spikes, max_consecutive_collisions, delta_collision_samples, unit_inds_to_channel_indices ) overlapping_spike_indices = overlapping_mask[:, max_consecutive_collisions] - print( - f"Found {len(overlapping_spike_indices)} overlapping spikes in segment {segment_index}! 
- chunk {start_frame} - {end_frame}" - ) else: overlapping_spike_indices = np.array([], dtype=int) @@ -368,17 +385,18 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) else: local_waveform = traces_with_margin[cut_out_start:cut_out_end, sparse_indices] assert template.shape == local_waveform.shape - local_waveforms.append(local_waveform) - templates.append(template) + linregress_res = linregress(template.flatten(), local_waveform.flatten()) scalings[spike_index] = linregress_res[0] # deal with collisions if len(overlapping_spike_indices) > 0: for overlapping in overlapping_mask: + # the current spike is the one at the 'max_consecutive_collisions' position spike_index = overlapping[max_consecutive_collisions] overlapping_spikes = local_spikes[overlapping[overlapping >= 0]] - scaled_amps = _fit_collision( + current_spike_index_within_overlapping = np.where(overlapping >= 0)[0] == max_consecutive_collisions + scaled_amps = fit_collision( overlapping_spikes, traces_with_margin, start_frame, @@ -392,9 +410,12 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) cut_out_after, ) # get the right amplitude scaling - scalings[spike_index] = scaled_amps[np.where(overlapping >= 0)[0] == max_consecutive_collisions] + scalings[spike_index] = scaled_amps[current_spike_index_within_overlapping] + else: + scalings = np.array([]) + overlapping_mask = np.array([], shape=(0, max_consecutive_collisions + 1)) - return (scalings,) + return (scalings, overlapping_mask) ### Collision handling ### @@ -422,7 +443,7 @@ def _are_unit_indices_overlapping(unit_inds_to_channel_indices, i, j): return False -def _find_overlapping_mask(spikes, max_consecutive_spikes, delta_overlap_samples, unit_inds_to_channel_indices): +def find_overlapping_mask(spikes, max_consecutive_spikes, delta_overlap_samples, unit_inds_to_channel_indices): """ Finds the overlapping spikes for each spike in spikes and returns a boolean mask of shape (n_spikes, 2 * max_consecutive_spikes + 1). @@ -525,7 +546,7 @@ def _find_overlapping_mask(spikes, max_consecutive_spikes, delta_overlap_samples return overlapping_mask -def _fit_collision( +def fit_collision( overlapping_spikes, traces_with_margin, start_frame, @@ -537,8 +558,41 @@ def _fit_collision( unit_inds_to_channel_indices, cut_out_before, cut_out_after, + debug=True, ): - """ """ + """ + Compute the best fit for a collision between a spike and its overlapping spikes. + + Parameters + ---------- + overlapping_spikes: np.ndarray + A numpy array of shape (n_overlapping_spikes, ) containing the overlapping spikes (spike_dtype). + traces_with_margin: np.ndarray + A numpy array of shape (n_samples, n_channels) containing the traces with a margin. + start_frame: int + The start frame of the chunk for traces_with_margin. + end_frame: int + The end frame of the chunk for traces_with_margin. + left: int + The left margin of the chunk for traces_with_margin. + right: int + The right margin of the chunk for traces_with_margin. + nbefore: int + The number of samples before the spike to consider for the fit. + all_templates: np.ndarray + A numpy array of shape (n_units, n_samples, n_channels) containing the templates. + unit_inds_to_channel_indices: dict + A dictionary mapping unit indices to channel indices. + cut_out_before: int + The number of samples to cut out before the spike. + cut_out_after: int + The number of samples to cut out after the spike. + + Returns + ------- + np.ndarray + The fitted scaling factors for the overlapping spikes. 
+ """ from sklearn.linear_model import LinearRegression sample_first_centered = overlapping_spikes[0]["sample_index"] - start_frame - left @@ -550,6 +604,7 @@ def _fit_collision( sparse_indices_i = unit_inds_to_channel_indices[spike["unit_index"]] sparse_indices = np.union1d(sparse_indices, sparse_indices_i) + # TODO: check alignment!!! local_waveform_start = max(0, sample_first_centered - cut_out_before) local_waveform_end = min(traces_with_margin.shape[0], sample_last_centered + cut_out_after) local_waveform = traces_with_margin[local_waveform_start:local_waveform_end, sparse_indices] @@ -559,7 +614,7 @@ def _fit_collision( for i, spike in enumerate(overlapping_spikes): full_template = np.zeros_like(local_waveform) # center wrt cutout traces - sample_centered = spike["sample_index"] - local_waveform_start + sample_centered = spike["sample_index"] - start_frame - left - local_waveform_start template = all_templates[spike["unit_index"]][:, sparse_indices] template_cut = template[nbefore - cut_out_before : nbefore + cut_out_after] # deal with borders @@ -571,86 +626,136 @@ def _fit_collision( full_template[sample_centered - cut_out_before : sample_centered + cut_out_after] = template_cut X[:, i] = full_template.T.flatten() + if debug: + import matplotlib.pyplot as plt + + fig, ax = plt.subplots() + max_tr = np.max(np.abs(local_waveform)) + + _ = ax.plot(y, color="k") + + for i, spike in enumerate(overlapping_spikes): + _ = ax.plot(X[:, i], color=f"C{i}", alpha=0.5) + plt.show() + reg = LinearRegression().fit(X, y) - amps = reg.coef_ - return amps + scalings = reg.coef_ + return scalings + + +def plot_collisions(we, sparsity=None, num_collisions=None): + """ + Plot the fitting of collision spikes. + + Parameters + ---------- + we : WaveformExtractor + The WaveformExtractor object. + sparsity : ChannelSparsity, default=None + The ChannelSparsity. If None, only main channels are plotted. + num_collisions : int, default=None + Number of collisions to plot. If None, all collisions are plotted. + """ + assert we.is_extension("amplitude_scalings"), "Could not find amplitude scalings extension!" + sac = we.load_extension("amplitude_scalings") + handle_collisions = sac._params["handle_collisions"] + assert handle_collisions, "Amplitude scalings was run without handling collisions!" 
+ scalings = sac.get_data() + + overlapping_mask = sac.overlapping_mask + num_collisions = num_collisions or len(overlapping_mask) + spikes = sac.spikes + max_consecutive_collisions = sac._params["max_consecutive_collisions"] + + for i in range(num_collisions): + ax = _plot_one_collision( + we, overlapping_mask[i], spikes, scalings=scalings, max_consecutive_collisions=max_consecutive_collisions + ) + + +def _plot_one_collision( + we, + overlap, + spikes, + scalings=None, + sparsity=None, + cut_out_samples=100, + max_consecutive_collisions=3, +): + import matplotlib.pyplot as plt + + recording = we.recording + nbefore_nafter_max = max(we.nafter, we.nbefore) + cut_out_samples = max(cut_out_samples, nbefore_nafter_max) + spike_index = overlap[max_consecutive_collisions] + overlap_indices = overlap[overlap != -1] + overlapping_spikes = spikes[overlap_indices] + + if sparsity is not None: + unit_inds_to_channel_indices = sparsity.unit_id_to_channel_indices + sparse_indices = np.array([], dtype="int") + for spike in overlapping_spikes: + sparse_indices_i = unit_inds_to_channel_indices[we.unit_ids[spike["unit_index"]]] + sparse_indices = np.union1d(sparse_indices, sparse_indices_i) + else: + sparse_indices = np.unique(overlapping_spikes["channel_index"]) + + channel_ids = recording.channel_ids[sparse_indices] + + center_spike = spikes[spike_index] + max_delta = np.max( + [ + np.abs(center_spike["sample_index"] - overlapping_spikes[0]["sample_index"]), + np.abs(center_spike["sample_index"] - overlapping_spikes[-1]["sample_index"]), + ] + ) + sf = max(0, center_spike["sample_index"] - max_delta - cut_out_samples) + ef = min( + center_spike["sample_index"] + max_delta + cut_out_samples, + recording.get_num_samples(segment_index=center_spike["segment_index"]), + ) + tr_overlap = recording.get_traces(start_frame=sf, end_frame=ef, channel_ids=channel_ids, return_scaled=True) + ts = np.arange(sf, ef) / recording.sampling_frequency * 1000 + max_tr = np.max(np.abs(tr_overlap)) + fig, ax = plt.subplots() + for ch, tr in enumerate(tr_overlap.T): + _ = ax.plot(ts, tr + 1.2 * ch * max_tr, color="k") + ax.text(ts[0], 1.2 * ch * max_tr - 0.3 * max_tr, f"Ch:{channel_ids[ch]}") + + used_labels = [] + for spike in overlapping_spikes: + label = f"U{spike['unit_index']}" + if label in used_labels: + label = None + else: + used_labels.append(label) + ax.axvline( + spike["sample_index"] / recording.sampling_frequency * 1000, color=f"C{spike['unit_index']}", label=label + ) + + if scalings is not None: + fitted_traces = np.zeros_like(tr_overlap) + + all_templates = we.get_all_templates() + for i, spike in enumerate(overlapping_spikes): + template = all_templates[spike["unit_index"]] + template_scaled = scalings[overlap_indices[i]] * template + template_scaled_sparse = template_scaled[:, sparse_indices] + sample_start = spike["sample_index"] - we.nbefore + sample_end = sample_start + template_scaled_sparse.shape[0] + + fitted_traces[sample_start - sf : sample_end - sf] += template_scaled_sparse + + for ch, temp in enumerate(template_scaled_sparse.T): + ts_template = np.arange(sample_start, sample_end) / recording.sampling_frequency * 1000 + _ = ax.plot(ts_template, temp + 1.2 * ch * max_tr, color=f"C{spike['unit_index']}", ls="--") + + for ch, tr in enumerate(fitted_traces.T): + _ = ax.plot(ts, tr + 1.2 * ch * max_tr, color="gray", alpha=0.7) + fitted_line = ax.get_lines()[-1] + fitted_line.set_label("Fitted") -# TODO: fix this! 
-# def plot_overlapping_spikes(we, overlap, -# spikes, cut_out_samples=100, -# max_consecutive_spikes=3, -# sparsity=None, -# fitted_amps=None): -# recording = we.recording -# nbefore_nafter_max = max(we.nafter, we.nbefore) -# cut_out_samples = max(cut_out_samples, nbefore_nafter_max) -# spike_index = overlap[max_consecutive_spikes] -# overlap_indices = overlap[overlap != -1] -# overlapping_spikes = spikes[overlap_indices] - -# if sparsity is not None: -# unit_inds_to_channel_indices = sparsity.unit_id_to_channel_indices -# sparse_indices = np.array([], dtype="int") -# for spike in overlapping_spikes: -# sparse_indices_i = unit_inds_to_channel_indices[we.unit_ids[spike["unit_index"]]] -# sparse_indices = np.union1d(sparse_indices, sparse_indices_i) -# else: -# sparse_indices = np.unique(overlapping_spikes["channel_index"]) - -# channel_ids = recording.channel_ids[sparse_indices] - -# center_spike = spikes[spike_index]["sample_index"] -# max_delta = np.max([np.abs(center_spike - overlapping_spikes[0]["sample_index"]), -# np.abs(center_spike - overlapping_spikes[-1]["sample_index"])]) -# sf = center_spike - max_delta - cut_out_samples -# ef = center_spike + max_delta + cut_out_samples -# tr_overlap = recording.get_traces(start_frame=sf, -# end_frame=ef, -# channel_ids=channel_ids, return_scaled=True) -# ts = np.arange(sf, ef) / recording.sampling_frequency * 1000 -# max_tr = np.max(np.abs(tr_overlap)) -# fig, ax = plt.subplots() -# for ch, tr in enumerate(tr_overlap.T): -# _ = ax.plot(ts, tr + 1.2 * ch * max_tr, color="k") -# ax.text(ts[0], 1.2 * ch * max_tr - 0.3 * max_tr, f"Ch:{channel_ids[ch]}") - -# used_labels = [] -# for spike in overlapping_spikes: -# label = f"U{spike['unit_index']}" -# if label in used_labels: -# label = None -# else: -# used_labels.append(label) -# ax.axvline(spike["sample_index"] / recording.sampling_frequency * 1000, -# color=f"C{spike['unit_index']}", label=label) - -# if fitted_amps is not None: -# fitted_traces = np.zeros_like(tr_overlap) - -# all_templates = we.get_all_templates() -# for i, spike in enumerate(overlapping_spikes): -# template = all_templates[spike["unit_index"]] -# template_scaled = fitted_amps[overlap_indices[i]] * template -# template_scaled_sparse = template_scaled[:, sparse_indices] -# sample_start = spike["sample_index"] - we.nbefore -# sample_end = sample_start + template_scaled_sparse.shape[0] - -# fitted_traces[sample_start - sf: sample_end - sf] += template_scaled_sparse - -# for ch, temp in enumerate(template_scaled_sparse.T): - -# ts_template = np.arange(sample_start, sample_end) / recording.sampling_frequency * 1000 -# _ = ax.plot(ts_template, temp + 1.2 * ch * max_tr, color=f"C{spike['unit_index']}", -# ls="--") - -# for ch, tr in enumerate(fitted_traces.T): -# _ = ax.plot(ts, tr + 1.2 * ch * max_tr, color="gray", alpha=0.7) - -# fitted_line = ax.get_lines()[-1] -# fitted_line.set_label("Fitted") - - -# ax.legend() -# ax.set_title(f"Spike {spike_index} - sample {center_spike}") -# return tr_overlap, ax + ax.legend() + ax.set_title(f"Spike {spike_index} - sample {center_spike['sample_index']}") + return ax From 7bec9df5c0298c853f552989ef5e3febcf0f9470 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 28 Aug 2023 17:52:07 +0200 Subject: [PATCH 12/35] Simplify and cleanup --- .../postprocessing/amplitude_scalings.py | 484 ++++++++---------- 1 file changed, 206 insertions(+), 278 deletions(-) diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 
d367ef4f22..1f7923eb05 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -7,6 +7,9 @@ from spikeinterface.core.waveform_extractor import WaveformExtractor, BaseWaveformExtractorExtension +# DEBUG = True + + class AmplitudeScalingsCalculator(BaseWaveformExtractorExtension): """ Computes amplitude scalings from WaveformExtractor. @@ -21,7 +24,6 @@ def __init__(self, waveform_extractor): self.spikes = self.waveform_extractor.sorting.to_spike_vector( extremum_channel_inds=extremum_channel_inds, use_cache=False ) - self.overlapping_mask = None def _set_params( self, @@ -30,7 +32,6 @@ def _set_params( ms_before, ms_after, handle_collisions, - max_consecutive_collisions, delta_collision_ms, ): params = dict( @@ -39,7 +40,6 @@ def _set_params( ms_before=ms_before, ms_after=ms_after, handle_collisions=handle_collisions, - max_consecutive_collisions=max_consecutive_collisions, delta_collision_ms=delta_collision_ms, ) return params @@ -63,7 +63,6 @@ def _run(self, **job_kwargs): # collisions handle_collisions = self._params["handle_collisions"] - max_consecutive_collisions = self._params["max_consecutive_collisions"] delta_collision_ms = self._params["delta_collision_ms"] delta_collision_samples = int(delta_collision_ms / 1000 * we.sampling_frequency) @@ -120,7 +119,6 @@ def _run(self, **job_kwargs): cut_out_after, return_scaled, handle_collisions, - max_consecutive_collisions, delta_collision_samples, ) processor = ChunkRecordingExecutor( @@ -133,32 +131,16 @@ def _run(self, **job_kwargs): **job_kwargs, ) out = processor.run() - (amp_scalings, overlapping_mask) = zip(*out) + (amp_scalings, collisions) = zip(*out) amp_scalings = np.concatenate(amp_scalings) - if handle_collisions > 0: - from ..core.job_tools import divide_recording_into_chunks - - overlapping_mask_corrected = [] - all_chunks = divide_recording_into_chunks(processor.recording, processor.chunk_size) - num_spikes_so_far = 0 - for i, overlapping in enumerate(overlapping_mask): - if i == 0: - continue - segment_index = all_chunks[i - 1][0] - spikes_in_segment = self.spikes[segment_slices[segment_index]] - i0 = np.searchsorted(spikes_in_segment["sample_index"], all_chunks[i - 1][1]) - i1 = np.searchsorted(spikes_in_segment["sample_index"], all_chunks[i - 1][2]) - num_spikes_so_far += i1 - i0 - overlapping_corrected = overlapping.copy() - overlapping_corrected[overlapping_corrected >= 0] += num_spikes_so_far - overlapping_mask_corrected.append(overlapping_corrected) - overlapping_mask = np.concatenate(overlapping_mask_corrected) - print(f"Found {len(overlapping_mask)} overlapping spikes") - self.overlapping_mask = overlapping_mask - else: - overlapping_mask = np.concatenate(overlapping_mask) + + collisions_dict = {} + if handle_collisions: + for collision in collisions: + collisions_dict.update(collision) self._extension_data[f"amplitude_scalings"] = amp_scalings + self._extension_data[f"collisions"] = collisions_dict def get_data(self, outputs="concatenated"): """ @@ -206,7 +188,6 @@ def compute_amplitude_scalings( ms_before=None, ms_after=None, handle_collisions=False, - max_consecutive_collisions=3, delta_collision_ms=2, load_if_exists=False, outputs="concatenated", @@ -235,10 +216,8 @@ def compute_amplitude_scalings( Whether to handle collisions between spikes. 
If True, the amplitude scaling of colliding spikes (defined as spikes within `delta_collision_ms` ms and with overlapping sparsity) is computed by fitting a multi-linear regression model (with `sklearn.LinearRegression`). If False, each spike is fitted independently. - max_consecutive_collisions: int, default: 3 - The maximum number of consecutive collisions to handle on each side of a spike. delta_collision_ms: float, default: 2 - The maximum time difference in ms between two spikes to be considered as colliding. + The maximum time difference in ms before and after a spike to gather colliding spikes. load_if_exists : bool, default: False Whether to load precomputed spike amplitudes, if they already exist. outputs: str, default: 'concatenated' @@ -264,7 +243,6 @@ def compute_amplitude_scalings( ms_before=ms_before, ms_after=ms_after, handle_collisions=handle_collisions, - max_consecutive_collisions=max_consecutive_collisions, delta_collision_ms=delta_collision_ms, ) sac.run(**job_kwargs) @@ -288,7 +266,6 @@ def _init_worker_amplitude_scalings( cut_out_after, return_scaled, handle_collisions, - max_consecutive_collisions, delta_collision_samples, ): # create a local dict per worker @@ -304,15 +281,15 @@ def _init_worker_amplitude_scalings( worker_ctx["return_scaled"] = return_scaled worker_ctx["unit_inds_to_channel_indices"] = unit_inds_to_channel_indices worker_ctx["handle_collisions"] = handle_collisions - worker_ctx["max_consecutive_collisions"] = max_consecutive_collisions worker_ctx["delta_collision_samples"] = delta_collision_samples if not handle_collisions: worker_ctx["margin"] = max(nbefore, nafter) else: + # in this case we extend the margin to be able to get with collisions outside the chunk margin_waveforms = max(nbefore, nafter) - max_margin_collisions = int(max_consecutive_collisions * delta_collision_samples) - worker_ctx["margin"] = max(margin_waveforms, max_margin_collisions) + max_margin_collisions = delta_collision_samples + margin_waveforms + worker_ctx["margin"] = max_margin_collisions return worker_ctx @@ -332,7 +309,6 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) margin = worker_ctx["margin"] return_scaled = worker_ctx["return_scaled"] handle_collisions = worker_ctx["handle_collisions"] - max_consecutive_collisions = worker_ctx["max_consecutive_collisions"] delta_collision_samples = worker_ctx["delta_collision_samples"] spikes_in_segment = spikes[segment_slices[segment_index]] @@ -355,17 +331,21 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) # set colliding spikes apart (if needed) if handle_collisions: - overlapping_mask = find_overlapping_mask( - local_spikes, max_consecutive_collisions, delta_collision_samples, unit_inds_to_channel_indices + # local spikes with margin! 
+ i0_margin = np.searchsorted(spikes_in_segment["sample_index"], start_frame - left) + i1_margin = np.searchsorted(spikes_in_segment["sample_index"], end_frame + right) + local_spikes_w_margin = spikes_in_segment[i0_margin:i1_margin] + collisions = find_collisions( + local_spikes, local_spikes_w_margin, delta_collision_samples, unit_inds_to_channel_indices ) - overlapping_spike_indices = overlapping_mask[:, max_consecutive_collisions] else: - overlapping_spike_indices = np.array([], dtype=int) + collisions = {} - # get all waveforms + # compute the scaling for each spike scalings = np.zeros(len(local_spikes), dtype=float) + collisions_dict = {} for spike_index, spike in enumerate(local_spikes): - if spike_index in overlapping_spike_indices: + if spike_index in collisions.keys(): # we deal with overlapping spikes later continue unit_index = spike["unit_index"] @@ -390,14 +370,13 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) scalings[spike_index] = linregress_res[0] # deal with collisions - if len(overlapping_spike_indices) > 0: - for overlapping in overlapping_mask: - # the current spike is the one at the 'max_consecutive_collisions' position - spike_index = overlapping[max_consecutive_collisions] - overlapping_spikes = local_spikes[overlapping[overlapping >= 0]] - current_spike_index_within_overlapping = np.where(overlapping >= 0)[0] == max_consecutive_collisions + if len(collisions) > 0: + num_spikes_in_previous_segments = int( + np.sum([len(spikes[segment_slices[s]]) for s in range(segment_index)]) + ) + for spike_index, collision in collisions.items(): scaled_amps = fit_collision( - overlapping_spikes, + collision, traces_with_margin, start_frame, end_frame, @@ -409,13 +388,16 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) cut_out_before, cut_out_after, ) - # get the right amplitude scaling - scalings[spike_index] = scaled_amps[current_spike_index_within_overlapping] + # the scaling for the current spike is at index 0 + scalings[spike_index] = scaled_amps[0] + + # make collision_dict indices "absolute" by adding i0 and the cumulative number of spikes in previous segments + collisions_dict.update({spike_index + i0 + num_spikes_in_previous_segments: collision}) else: scalings = np.array([]) - overlapping_mask = np.array([], shape=(0, max_consecutive_collisions + 1)) + collisions_dict = {} - return (scalings, overlapping_mask) + return (scalings, collisions_dict) ### Collision handling ### @@ -443,111 +425,65 @@ def _are_unit_indices_overlapping(unit_inds_to_channel_indices, i, j): return False -def find_overlapping_mask(spikes, max_consecutive_spikes, delta_overlap_samples, unit_inds_to_channel_indices): +def find_collisions(spikes, spikes_w_margin, delta_collision_samples, unit_inds_to_channel_indices): """ - Finds the overlapping spikes for each spike in spikes and returns a boolean mask of shape - (n_spikes, 2 * max_consecutive_spikes + 1). + Finds the collisions between spikes. 
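A compact illustration of the criterion implemented here: two spikes are treated as colliding when they are closer in time than delta_collision_samples and their sparse channel sets intersect (the spatial test is the intersect1d check used by _are_unit_indices_overlapping). The snippet below is illustrative only; all inputs are invented, whereas the real function operates on the structured spike vector and the unit_index to channel_indices mapping.

    import numpy as np

    delta_collision_samples = 60
    unit_inds_to_channel_indices = {
        0: np.array([0, 1, 2]),
        1: np.array([2, 3]),
        2: np.array([7, 8]),
    }

    def are_colliding(sample_a, unit_a, sample_b, unit_b):
        close_in_time = abs(sample_a - sample_b) < delta_collision_samples
        share_channels = (
            np.intersect1d(unit_inds_to_channel_indices[unit_a], unit_inds_to_channel_indices[unit_b]).size > 0
        )
        return close_in_time and share_channels

    are_colliding(1000, 0, 1030, 1)  # True: 30 samples apart and units 0/1 share channel 2
    are_colliding(1000, 0, 1030, 2)  # False: close in time but spatially disjoint
    are_colliding(1000, 0, 2000, 1)  # False: overlapping channels but too far apart in time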
Parameters ---------- spikes: np.array An array of spikes - max_consecutive_spikes: int - The maximum number of consecutive spikes to consider - delta_overlap_samples: int + spikes_w_margin: np.array + An array of spikes within the added margin + delta_collision_samples: int The maximum number of samples between two spikes to consider them as overlapping unit_inds_to_channel_indices: dict A dictionary mapping unit indices to channel indices Returns ------- - overlapping_mask: np.array - A boolean mask of shape (n_spikes, 2 * max_consecutive_spikes + 1) where the central column (max_consecutive_spikes) - is the current spike index, while the other columns are the indices of the overlapping spikes. The first - max_consecutive_spikes columns are the pre-overlapping spikes, while the last max_consecutive_spikes columns are - the post-overlapping spikes. + collision_spikes_dict: np.array + A dictionary with collisions. The key is the index of the spike with collision, the value is an + array of overlapping spikes, including the spike itself at position 0. """ + collision_spikes_dict = {} + for spike_index, spike in enumerate(spikes): + # find the index of the spike within the spikes_w_margin + spike_index_w_margin = np.where(spikes_w_margin == spike)[0][0] - # overlapping_mask is a matrix of shape (n_spikes, 2 * max_consecutive_spikes + 1) - # the central column (max_consecutive_spikes) is the current spike index, while the other columns are the - # indices of the overlapping spikes. The first max_consecutive_spikes columns are the pre-overlapping spikes, - # while the last max_consecutive_spikes columns are the post-overlapping spikes - # Rows with all -1 are non-colliding spikes and are removed later - overlapping_mask_full = -1 * np.ones((len(spikes), 2 * max_consecutive_spikes + 1), dtype=int) - overlapping_mask_full[:, max_consecutive_spikes] = np.arange(len(spikes)) - - for i, spike in enumerate(spikes): - # find the possible spikes per and post within max_consecutive_spikes * delta_overlap_samples + # find the possible spikes per and post within delta_collision_samples consecutive_window_pre = np.searchsorted( - spikes["sample_index"], - spike["sample_index"] - max_consecutive_spikes * delta_overlap_samples, + spikes_w_margin["sample_index"], + spike["sample_index"] - delta_collision_samples, ) consecutive_window_post = np.searchsorted( - spikes["sample_index"], - spike["sample_index"] + max_consecutive_spikes * delta_overlap_samples, + spikes_w_margin["sample_index"], + spike["sample_index"] + delta_collision_samples, + ) + # exclude the spike itself (it is included in the collision_spikes by construction) + pre_possible_consecutive_spike_indices = np.arange(consecutive_window_pre, spike_index_w_margin) + post_possible_consecutive_spike_indices = np.arange(spike_index_w_margin + 1, consecutive_window_post) + possible_overlapping_spike_indices = np.concatenate( + (pre_possible_consecutive_spike_indices, post_possible_consecutive_spike_indices) ) - pre_possible_consecutive_spikes = spikes[consecutive_window_pre:i][::-1] - post_possible_consecutive_spikes = spikes[i + 1 : consecutive_window_post] - - # here we fill in the overlapping information by consecutively looping through the possible consecutive spikes - # and checking the spatial overlap and the delay with the previous overlapping spike - # pre and post are hanlded separately. 
Note that the pre-spikes are already sorted backwards - - # overlap_rank keeps track of the rank of consecutive collisions (i.e., rank 0 is the first, rank 1 is the second, etc.) - # this is needed because we are just considering spikes with spatial overlap, while the possible consecutive spikes - # only looked at the temporal overlap - overlap_rank = 0 - if len(pre_possible_consecutive_spikes) > 0: - for c_pre, spike_consecutive_pre in enumerate(pre_possible_consecutive_spikes[::-1]): - if _are_unit_indices_overlapping( - unit_inds_to_channel_indices, spike["unit_index"], spike_consecutive_pre["unit_index"] - ): - if ( - spikes[overlapping_mask_full[i, max_consecutive_spikes - overlap_rank]]["sample_index"] - - spike_consecutive_pre["sample_index"] - < delta_overlap_samples - ): - overlapping_mask_full[i, max_consecutive_spikes - overlap_rank - 1] = i - 1 - c_pre - overlap_rank += 1 - else: - break - # if overlap_rank > 1: - # print(f"\tHigher order pre-overlap for spike {i}!") - - overlap_rank = 0 - if len(post_possible_consecutive_spikes) > 0: - for c_post, spike_consecutive_post in enumerate(post_possible_consecutive_spikes): - if _are_unit_indices_overlapping( - unit_inds_to_channel_indices, spike["unit_index"], spike_consecutive_post["unit_index"] - ): - if ( - spike_consecutive_post["sample_index"] - - spikes[overlapping_mask_full[i, max_consecutive_spikes + overlap_rank]]["sample_index"] - < delta_overlap_samples - ): - overlapping_mask_full[i, max_consecutive_spikes + overlap_rank + 1] = i + 1 + c_post - overlap_rank += 1 - else: - break - # if overlap_rank > 1: - # print(f"\tHigher order post-overlap for spike {i}!") - - # in case no collisions were found, we set the central column to -1 so that we can easily identify the non-colliding spikes - if np.sum(overlapping_mask_full[i] != -1) == 1: - overlapping_mask_full[i, max_consecutive_spikes] = -1 - - # only return rows with collisions - overlapping_inds = [] - for i, overlapping in enumerate(overlapping_mask_full): - if np.any(overlapping >= 0): - overlapping_inds.append(i) - overlapping_mask = overlapping_mask_full[overlapping_inds] - - return overlapping_mask + + # find the overlapping spikes in space as well + for possible_overlapping_spike_index in possible_overlapping_spike_indices: + if _are_unit_indices_overlapping( + unit_inds_to_channel_indices, + spike["unit_index"], + spikes_w_margin[possible_overlapping_spike_index]["unit_index"], + ): + if spike_index not in collision_spikes_dict: + collision_spikes_dict[spike_index] = np.array([spike]) + collision_spikes_dict[spike_index] = np.concatenate( + (collision_spikes_dict[spike_index], [spikes_w_margin[possible_overlapping_spike_index]]) + ) + return collision_spikes_dict def fit_collision( - overlapping_spikes, + collision, traces_with_margin, start_frame, end_frame, @@ -558,15 +494,16 @@ def fit_collision( unit_inds_to_channel_indices, cut_out_before, cut_out_after, - debug=True, ): """ Compute the best fit for a collision between a spike and its overlapping spikes. + The function first cuts out the traces around the spike and its overlapping spikes, then + fits a multi-linear regression model to the traces using the centered templates as predictors. Parameters ---------- - overlapping_spikes: np.ndarray - A numpy array of shape (n_overlapping_spikes, ) containing the overlapping spikes (spike_dtype). + collision: np.ndarray + A numpy array of shape (n_colliding_spikes, ) containing the colliding spikes (spike_dtype). 
traces_with_margin: np.ndarray A numpy array of shape (n_samples, n_channels) containing the traces with a margin. start_frame: int @@ -591,30 +528,30 @@ def fit_collision( Returns ------- np.ndarray - The fitted scaling factors for the overlapping spikes. + The fitted scaling factors for the colliding spikes. """ from sklearn.linear_model import LinearRegression - sample_first_centered = overlapping_spikes[0]["sample_index"] - start_frame - left - sample_last_centered = overlapping_spikes[-1]["sample_index"] - start_frame - left + # make center of the spike externally + sample_first_centered = np.min(collision["sample_index"]) - (start_frame - left) + sample_last_centered = np.max(collision["sample_index"]) - (start_frame - left) # construct sparsity as union between units' sparsity sparse_indices = np.array([], dtype="int") - for spike in overlapping_spikes: + for spike in collision: sparse_indices_i = unit_inds_to_channel_indices[spike["unit_index"]] sparse_indices = np.union1d(sparse_indices, sparse_indices_i) - # TODO: check alignment!!! local_waveform_start = max(0, sample_first_centered - cut_out_before) local_waveform_end = min(traces_with_margin.shape[0], sample_last_centered + cut_out_after) local_waveform = traces_with_margin[local_waveform_start:local_waveform_end, sparse_indices] y = local_waveform.T.flatten() - X = np.zeros((len(y), len(overlapping_spikes))) - for i, spike in enumerate(overlapping_spikes): + X = np.zeros((len(y), len(collision))) + for i, spike in enumerate(collision): full_template = np.zeros_like(local_waveform) # center wrt cutout traces - sample_centered = spike["sample_index"] - start_frame - left - local_waveform_start + sample_centered = spike["sample_index"] - (start_frame - left) - local_waveform_start template = all_templates[spike["unit_index"]][:, sparse_indices] template_cut = template[nbefore - cut_out_before : nbefore + cut_out_after] # deal with borders @@ -626,136 +563,127 @@ def fit_collision( full_template[sample_centered - cut_out_before : sample_centered + cut_out_after] = template_cut X[:, i] = full_template.T.flatten() - if debug: - import matplotlib.pyplot as plt - - fig, ax = plt.subplots() - max_tr = np.max(np.abs(local_waveform)) - - _ = ax.plot(y, color="k") - - for i, spike in enumerate(overlapping_spikes): - _ = ax.plot(X[:, i], color=f"C{i}", alpha=0.5) - plt.show() - reg = LinearRegression().fit(X, y) scalings = reg.coef_ return scalings -def plot_collisions(we, sparsity=None, num_collisions=None): - """ - Plot the fitting of collision spikes. - - Parameters - ---------- - we : WaveformExtractor - The WaveformExtractor object. - sparsity : ChannelSparsity, default=None - The ChannelSparsity. If None, only main channels are plotted. - num_collisions : int, default=None - Number of collisions to plot. If None, all collisions are plotted. - """ - assert we.is_extension("amplitude_scalings"), "Could not find amplitude scalings extension!" - sac = we.load_extension("amplitude_scalings") - handle_collisions = sac._params["handle_collisions"] - assert handle_collisions, "Amplitude scalings was run without handling collisions!" 
- scalings = sac.get_data() - - overlapping_mask = sac.overlapping_mask - num_collisions = num_collisions or len(overlapping_mask) - spikes = sac.spikes - max_consecutive_collisions = sac._params["max_consecutive_collisions"] - - for i in range(num_collisions): - ax = _plot_one_collision( - we, overlapping_mask[i], spikes, scalings=scalings, max_consecutive_collisions=max_consecutive_collisions - ) - - -def _plot_one_collision( - we, - overlap, - spikes, - scalings=None, - sparsity=None, - cut_out_samples=100, - max_consecutive_collisions=3, -): - import matplotlib.pyplot as plt - - recording = we.recording - nbefore_nafter_max = max(we.nafter, we.nbefore) - cut_out_samples = max(cut_out_samples, nbefore_nafter_max) - spike_index = overlap[max_consecutive_collisions] - overlap_indices = overlap[overlap != -1] - overlapping_spikes = spikes[overlap_indices] - - if sparsity is not None: - unit_inds_to_channel_indices = sparsity.unit_id_to_channel_indices - sparse_indices = np.array([], dtype="int") - for spike in overlapping_spikes: - sparse_indices_i = unit_inds_to_channel_indices[we.unit_ids[spike["unit_index"]]] - sparse_indices = np.union1d(sparse_indices, sparse_indices_i) - else: - sparse_indices = np.unique(overlapping_spikes["channel_index"]) - - channel_ids = recording.channel_ids[sparse_indices] - - center_spike = spikes[spike_index] - max_delta = np.max( - [ - np.abs(center_spike["sample_index"] - overlapping_spikes[0]["sample_index"]), - np.abs(center_spike["sample_index"] - overlapping_spikes[-1]["sample_index"]), - ] - ) - sf = max(0, center_spike["sample_index"] - max_delta - cut_out_samples) - ef = min( - center_spike["sample_index"] + max_delta + cut_out_samples, - recording.get_num_samples(segment_index=center_spike["segment_index"]), - ) - tr_overlap = recording.get_traces(start_frame=sf, end_frame=ef, channel_ids=channel_ids, return_scaled=True) - ts = np.arange(sf, ef) / recording.sampling_frequency * 1000 - max_tr = np.max(np.abs(tr_overlap)) - fig, ax = plt.subplots() - for ch, tr in enumerate(tr_overlap.T): - _ = ax.plot(ts, tr + 1.2 * ch * max_tr, color="k") - ax.text(ts[0], 1.2 * ch * max_tr - 0.3 * max_tr, f"Ch:{channel_ids[ch]}") - - used_labels = [] - for spike in overlapping_spikes: - label = f"U{spike['unit_index']}" - if label in used_labels: - label = None - else: - used_labels.append(label) - ax.axvline( - spike["sample_index"] / recording.sampling_frequency * 1000, color=f"C{spike['unit_index']}", label=label - ) - - if scalings is not None: - fitted_traces = np.zeros_like(tr_overlap) - - all_templates = we.get_all_templates() - for i, spike in enumerate(overlapping_spikes): - template = all_templates[spike["unit_index"]] - template_scaled = scalings[overlap_indices[i]] * template - template_scaled_sparse = template_scaled[:, sparse_indices] - sample_start = spike["sample_index"] - we.nbefore - sample_end = sample_start + template_scaled_sparse.shape[0] - - fitted_traces[sample_start - sf : sample_end - sf] += template_scaled_sparse - - for ch, temp in enumerate(template_scaled_sparse.T): - ts_template = np.arange(sample_start, sample_end) / recording.sampling_frequency * 1000 - _ = ax.plot(ts_template, temp + 1.2 * ch * max_tr, color=f"C{spike['unit_index']}", ls="--") - - for ch, tr in enumerate(fitted_traces.T): - _ = ax.plot(ts, tr + 1.2 * ch * max_tr, color="gray", alpha=0.7) - - fitted_line = ax.get_lines()[-1] - fitted_line.set_label("Fitted") - - ax.legend() - ax.set_title(f"Spike {spike_index} - sample {center_spike['sample_index']}") - return ax +# 
uncomment for debugging +# def plot_collisions(we, sparsity=None, num_collisions=None): +# """ +# Plot the fitting of collision spikes. + +# Parameters +# ---------- +# we : WaveformExtractor +# The WaveformExtractor object. +# sparsity : ChannelSparsity, default=None +# The ChannelSparsity. If None, only main channels are plotted. +# num_collisions : int, default=None +# Number of collisions to plot. If None, all collisions are plotted. +# """ +# assert we.is_extension("amplitude_scalings"), "Could not find amplitude scalings extension!" +# sac = we.load_extension("amplitude_scalings") +# handle_collisions = sac._params["handle_collisions"] +# assert handle_collisions, "Amplitude scalings was run without handling collisions!" +# scalings = sac.get_data() + +# # overlapping_mask = sac.overlapping_mask +# # num_collisions = num_collisions or len(overlapping_mask) +# spikes = sac.spikes +# collisions = sac._extension_data[f"collisions"] +# collision_keys = list(collisions.keys()) +# num_collisions = num_collisions or len(collisions) +# num_collisions = min(num_collisions, len(collisions)) + +# for i in range(num_collisions): +# overlapping_spikes = collisions[collision_keys[i]] +# ax = _plot_one_collision( +# we, collision_keys[i], overlapping_spikes, spikes, scalings=scalings, sparsity=sparsity +# ) + + +# def _plot_one_collision( +# we, +# spike_index, +# overlapping_spikes, +# spikes, +# scalings=None, +# sparsity=None, +# cut_out_samples=100, +# ): +# import matplotlib.pyplot as plt + +# recording = we.recording +# nbefore_nafter_max = max(we.nafter, we.nbefore) +# cut_out_samples = max(cut_out_samples, nbefore_nafter_max) + +# if sparsity is not None: +# unit_inds_to_channel_indices = sparsity.unit_id_to_channel_indices +# sparse_indices = np.array([], dtype="int") +# for spike in overlapping_spikes: +# sparse_indices_i = unit_inds_to_channel_indices[we.unit_ids[spike["unit_index"]]] +# sparse_indices = np.union1d(sparse_indices, sparse_indices_i) +# else: +# sparse_indices = np.unique(overlapping_spikes["channel_index"]) + +# channel_ids = recording.channel_ids[sparse_indices] + +# center_spike = overlapping_spikes[0] +# max_delta = np.max( +# [ +# np.abs(center_spike["sample_index"] - np.min(overlapping_spikes[1:]["sample_index"])), +# np.abs(center_spike["sample_index"] - np.max(overlapping_spikes[1:]["sample_index"])), +# ] +# ) +# sf = max(0, center_spike["sample_index"] - max_delta - cut_out_samples) +# ef = min( +# center_spike["sample_index"] + max_delta + cut_out_samples, +# recording.get_num_samples(segment_index=center_spike["segment_index"]), +# ) +# tr_overlap = recording.get_traces(start_frame=sf, end_frame=ef, channel_ids=channel_ids, return_scaled=True) +# ts = np.arange(sf, ef) / recording.sampling_frequency * 1000 +# max_tr = np.max(np.abs(tr_overlap)) +# fig, ax = plt.subplots() +# for ch, tr in enumerate(tr_overlap.T): +# _ = ax.plot(ts, tr + 1.2 * ch * max_tr, color="k") +# ax.text(ts[0], 1.2 * ch * max_tr - 0.3 * max_tr, f"Ch:{channel_ids[ch]}") + +# used_labels = [] +# for i, spike in enumerate(overlapping_spikes): +# label = f"U{spike['unit_index']}" +# if label in used_labels: +# label = None +# else: +# used_labels.append(label) +# ax.axvline( +# spike["sample_index"] / recording.sampling_frequency * 1000, color=f"C{spike['unit_index']}", label=label +# ) + +# if scalings is not None: +# fitted_traces = np.zeros_like(tr_overlap) + +# all_templates = we.get_all_templates() +# for i, spike in enumerate(overlapping_spikes): +# template = 
all_templates[spike["unit_index"]] +# overlap_index = np.where(spikes == spike)[0][0] +# template_scaled = scalings[overlap_index] * template +# template_scaled_sparse = template_scaled[:, sparse_indices] +# sample_start = spike["sample_index"] - we.nbefore +# sample_end = sample_start + template_scaled_sparse.shape[0] + +# fitted_traces[sample_start - sf : sample_end - sf] += template_scaled_sparse + +# for ch, temp in enumerate(template_scaled_sparse.T): +# ts_template = np.arange(sample_start, sample_end) / recording.sampling_frequency * 1000 +# _ = ax.plot(ts_template, temp + 1.2 * ch * max_tr, color=f"C{spike['unit_index']}", ls="--") + +# for ch, tr in enumerate(fitted_traces.T): +# _ = ax.plot(ts, tr + 1.2 * ch * max_tr, color="gray", alpha=0.7) + +# fitted_line = ax.get_lines()[-1] +# fitted_line.set_label("Fitted") + +# ax.legend() +# ax.set_title(f"Spike {spike_index} - sample {center_spike['sample_index']}") +# return ax From 2d7b08f2744c550dd630add451e85c28f4f7336d Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 29 Aug 2023 09:55:09 +0200 Subject: [PATCH 13/35] Add zugbruecke in extractors install for plexon2 --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 3ecfbe2718..ddb0de4893 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,6 +68,7 @@ extractors = [ "ONE-api>=1.19.1", "ibllib>=2.21.0", "pymatreader>=0.0.32", # For cell explorer matlab files + "zugbruecke", # For plexon2 ] streaming_extractors = [ From 4a9f429f2ad3b18057db6c6432960c401f0ff14c Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 29 Aug 2023 09:57:16 +0200 Subject: [PATCH 14/35] Update naming following #1626 --- src/spikeinterface/extractors/neoextractors/plexon2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/extractors/neoextractors/plexon2.py b/src/spikeinterface/extractors/neoextractors/plexon2.py index c3869dbadc..148deb48e9 100644 --- a/src/spikeinterface/extractors/neoextractors/plexon2.py +++ b/src/spikeinterface/extractors/neoextractors/plexon2.py @@ -52,7 +52,7 @@ class Plexon2SortingExtractor(NeoBaseSortingExtractor): mode = "file" NeoRawIOClass = "Plexon2RawIO" - handle_spike_frame_directly = True + neo_returns_frames = True name = "plexon2" def __init__(self, file_path): From aae39c6d1a2f4c3f952e19a81e031eb7abb909ae Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 29 Aug 2023 10:34:17 +0200 Subject: [PATCH 15/35] Install wine for plexon2 --- .github/actions/build-test-environment/action.yml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/.github/actions/build-test-environment/action.yml b/.github/actions/build-test-environment/action.yml index 004fe31203..7b5debdd51 100644 --- a/.github/actions/build-test-environment/action.yml +++ b/.github/actions/build-test-environment/action.yml @@ -40,3 +40,9 @@ runs: tar xvzf git-annex-standalone-amd64.tar.gz echo "$(pwd)/git-annex.linux" >> $GITHUB_PATH shell: bash + - name: Install wine (needed for Plexon2) + run: | + sudo rm -f /etc/apt/sources.list.d/microsoft-prod.list + sudo dpkg --add-architecture i386 + sudo apt-get update -qq + sudo apt-get install -yqq --allow-downgrades libc6:i386 libgcc-s1:i386 libstdc++6:i386 wine From 26fdba2b0d85f5c174f60462d6edb6128876c14a Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 29 Aug 2023 10:36:23 +0200 Subject: [PATCH 16/35] Add shell --- .github/actions/build-test-environment/action.yml | 1 + 1 file changed, 1 insertion(+) diff --git 
a/.github/actions/build-test-environment/action.yml b/.github/actions/build-test-environment/action.yml index 7b5debdd51..b056bd3353 100644 --- a/.github/actions/build-test-environment/action.yml +++ b/.github/actions/build-test-environment/action.yml @@ -46,3 +46,4 @@ runs: sudo dpkg --add-architecture i386 sudo apt-get update -qq sudo apt-get install -yqq --allow-downgrades libc6:i386 libgcc-s1:i386 libstdc++6:i386 wine + shell: bash From 8bcec5f9df9d2205b0ecd222aac5df135492d730 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 29 Aug 2023 11:39:12 +0200 Subject: [PATCH 17/35] Expose sampling_frequency in pl2 sorting (needed for multi-stream) --- src/spikeinterface/extractors/neoextractors/plexon2.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/extractors/neoextractors/plexon2.py b/src/spikeinterface/extractors/neoextractors/plexon2.py index 148deb48e9..966fc253ad 100644 --- a/src/spikeinterface/extractors/neoextractors/plexon2.py +++ b/src/spikeinterface/extractors/neoextractors/plexon2.py @@ -48,6 +48,8 @@ class Plexon2SortingExtractor(NeoBaseSortingExtractor): ---------- file_path: str The file path to load the recordings from. + sampling_frequency: float, default: None + The sampling frequency of the sorting (required for multiple streams with different sampling frequencies). """ mode = "file" @@ -55,13 +57,13 @@ class Plexon2SortingExtractor(NeoBaseSortingExtractor): neo_returns_frames = True name = "plexon2" - def __init__(self, file_path): + def __init__(self, file_path, sampling_frequency=None): from neo.rawio import Plexon2RawIO neo_kwargs = self.map_to_neo_kwargs(file_path) neo_reader = Plexon2RawIO(**neo_kwargs) neo_reader.parse_header() - NeoBaseSortingExtractor.__init__(self, **neo_kwargs) + NeoBaseSortingExtractor.__init__(self, sampling_frequency=sampling_frequency, **neo_kwargs) self._kwargs.update({"file_path": str(file_path)}) @classmethod From 3ffb76c444e2556fd62efbfab677d6dcd1cd7706 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 29 Aug 2023 11:59:55 +0200 Subject: [PATCH 18/35] Add sampling_frequency kwargs in tests --- src/spikeinterface/extractors/tests/test_neoextractors.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/extractors/tests/test_neoextractors.py b/src/spikeinterface/extractors/tests/test_neoextractors.py index 0405d7b129..e8f565bede 100644 --- a/src/spikeinterface/extractors/tests/test_neoextractors.py +++ b/src/spikeinterface/extractors/tests/test_neoextractors.py @@ -156,9 +156,7 @@ class PlexonSortingTest(SortingCommonTestSuite, unittest.TestCase): class Plexon2SortingTest(SortingCommonTestSuite, unittest.TestCase): ExtractorClass = Plexon2SortingExtractor downloads = ["plexon"] - entities = [ - ("plexon/4chDemoPL2.pl2"), - ] + entities = [("plexon/4chDemoPL2.pl2", {"sampling_frequency": 40000})] class NeuralynxRecordingTest(RecordingCommonTestSuite, unittest.TestCase): @@ -328,7 +326,7 @@ def test_pickling(self): # test = PlexonRecordingTest() # test = PlexonSortingTest() # test = NeuralynxRecordingTest() - test = BlackrockRecordingTest() + test = Plexon2RecordingTest() # test = MCSRawRecordingTest() # test = KiloSortSortingTest() # test = Spike2RecordingTest() From 6247dc090b5604dca5dd73fb151f55f781bca8d3 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 29 Aug 2023 12:55:33 +0200 Subject: [PATCH 19/35] Update self._kwargs --- src/spikeinterface/extractors/neoextractors/plexon2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) 
diff --git a/src/spikeinterface/extractors/neoextractors/plexon2.py b/src/spikeinterface/extractors/neoextractors/plexon2.py index 966fc253ad..8dbfc67e90 100644 --- a/src/spikeinterface/extractors/neoextractors/plexon2.py +++ b/src/spikeinterface/extractors/neoextractors/plexon2.py @@ -64,7 +64,7 @@ def __init__(self, file_path, sampling_frequency=None): neo_reader = Plexon2RawIO(**neo_kwargs) neo_reader.parse_header() NeoBaseSortingExtractor.__init__(self, sampling_frequency=sampling_frequency, **neo_kwargs) - self._kwargs.update({"file_path": str(file_path)}) + self._kwargs.update({"file_path": str(file_path), "sampling_frequency": sampling_frequency}) @classmethod def map_to_neo_kwargs(cls, file_path): From dfcd3caf8a18168ae05b564d62e7ce15c3ac185d Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 29 Aug 2023 16:46:55 +0200 Subject: [PATCH 20/35] Mark plexon2 tests and install Wine only if needed --- .../actions/build-test-environment/action.yml | 7 --- .github/actions/install-wine/action.yml | 21 ++++++++ .github/workflows/full-test-with-codecov.yml | 2 +- .github/workflows/full-test.yml | 12 ++++- .../extractors/tests/test_neoextractors.py | 48 ++++++++++--------- 5 files changed, 59 insertions(+), 31 deletions(-) create mode 100644 .github/actions/install-wine/action.yml diff --git a/.github/actions/build-test-environment/action.yml b/.github/actions/build-test-environment/action.yml index b056bd3353..004fe31203 100644 --- a/.github/actions/build-test-environment/action.yml +++ b/.github/actions/build-test-environment/action.yml @@ -40,10 +40,3 @@ runs: tar xvzf git-annex-standalone-amd64.tar.gz echo "$(pwd)/git-annex.linux" >> $GITHUB_PATH shell: bash - - name: Install wine (needed for Plexon2) - run: | - sudo rm -f /etc/apt/sources.list.d/microsoft-prod.list - sudo dpkg --add-architecture i386 - sudo apt-get update -qq - sudo apt-get install -yqq --allow-downgrades libc6:i386 libgcc-s1:i386 libstdc++6:i386 wine - shell: bash diff --git a/.github/actions/install-wine/action.yml b/.github/actions/install-wine/action.yml new file mode 100644 index 0000000000..3ae08ecd34 --- /dev/null +++ b/.github/actions/install-wine/action.yml @@ -0,0 +1,21 @@ +name: Install packages +description: This action installs the package and its dependencies for testing + +inputs: + python-version: + description: 'Python version to set up' + required: false + os: + description: 'Operating system to set up' + required: false + +runs: + using: "composite" + steps: + - name: Install wine (needed for Plexon2) + run: | + sudo rm -f /etc/apt/sources.list.d/microsoft-prod.list + sudo dpkg --add-architecture i386 + sudo apt-get update -qq + sudo apt-get install -yqq --allow-downgrades libc6:i386 libgcc-s1:i386 libstdc++6:i386 wine + shell: bash diff --git a/.github/workflows/full-test-with-codecov.yml b/.github/workflows/full-test-with-codecov.yml index a5561c2ffc..d0bf109a00 100644 --- a/.github/workflows/full-test-with-codecov.yml +++ b/.github/workflows/full-test-with-codecov.yml @@ -54,7 +54,7 @@ jobs: - name: run tests run: | source ${{ github.workspace }}/test_env/bin/activate - pytest -m "not sorters_external" --cov=./ --cov-report xml:./coverage.xml -vv -ra --durations=0 | tee report_full.txt; test ${PIPESTATUS[0]} -eq 0 || exit 1 + pytest -m "not sorters_external and not plexon2" --cov=./ --cov-report xml:./coverage.xml -vv -ra --durations=0 | tee report_full.txt; test ${PIPESTATUS[0]} -eq 0 || exit 1 echo "# Timing profile of full tests" >> $GITHUB_STEP_SUMMARY python ./.github/build_job_summary.py 
report_full.txt >> $GITHUB_STEP_SUMMARY cat $GITHUB_STEP_SUMMARY diff --git a/.github/workflows/full-test.yml b/.github/workflows/full-test.yml index ac5130bade..a343500c08 100644 --- a/.github/workflows/full-test.yml +++ b/.github/workflows/full-test.yml @@ -75,6 +75,10 @@ jobs: echo "Extractors changed" echo "EXTRACTORS_CHANGED=true" >> $GITHUB_OUTPUT fi + if [[ $file == *"plexon2"* ]]; then + echo "Plexon2 changed" + echo "PLEXON2_CHANGED=true" >> $GITHUB_OUTPUT + fi if [[ $file == *"/preprocessing/"* ]]; then echo "Preprocessing changed" echo "PREPROCESSING_CHANGED=true" >> $GITHUB_OUTPUT @@ -122,11 +126,14 @@ jobs: done - name: Set execute permissions on run_tests.sh run: chmod +x .github/run_tests.sh + - name: Install Wine (Plexon2) + if: ${{ steps.modules-changed.outputs.PLEXON2_CHANGED == 'true' }} + uses: ./.github/actions/install-wine - name: Test core run: ./.github/run_tests.sh core - name: Test extractors if: ${{ steps.modules-changed.outputs.EXTRACTORS_CHANGED == 'true' || steps.modules-changed.outputs.CORE_CHANGED == 'true' }} - run: ./.github/run_tests.sh "extractors and not streaming_extractors" + run: ./.github/run_tests.sh "extractors and not streaming_extractors and not plexon2" - name: Test preprocessing if: ${{ steps.modules-changed.outputs.PREPROCESSING_CHANGED == 'true' || steps.modules-changed.outputs.CORE_CHANGED == 'true' }} run: ./.github/run_tests.sh preprocessing @@ -157,3 +164,6 @@ jobs: - name: Test internal sorters if: ${{ steps.modules-changed.outputs.SORTERS_INTERNAL_CHANGED == 'true' || steps.modules-changed.outputs.SORTINGCOMPONENTS_CHANGED || steps.modules-changed.outputs.CORE_CHANGED == 'true' }} run: ./.github/run_tests.sh sorters_internal + - name: Test plexon2 + if: ${{ steps.modules-changed.outputs.PLEXON2_CHANGED == 'true' }} + run: ./.github/run_tests.sh plexon2 diff --git a/src/spikeinterface/extractors/tests/test_neoextractors.py b/src/spikeinterface/extractors/tests/test_neoextractors.py index e8f565bede..da162eccf1 100644 --- a/src/spikeinterface/extractors/tests/test_neoextractors.py +++ b/src/spikeinterface/extractors/tests/test_neoextractors.py @@ -121,14 +121,6 @@ class NeuroScopeSortingTest(SortingCommonTestSuite, unittest.TestCase): ] -class Plexon2EventTest(EventCommonTestSuite, unittest.TestCase): - ExtractorClass = Plexon2EventExtractor - downloads = ["plexon"] - entities = [ - ("plexon/4chDemoPL2.pl2"), - ] - - class PlexonRecordingTest(RecordingCommonTestSuite, unittest.TestCase): ExtractorClass = PlexonRecordingExtractor downloads = ["plexon"] @@ -137,14 +129,6 @@ class PlexonRecordingTest(RecordingCommonTestSuite, unittest.TestCase): ] -class Plexon2RecordingTest(RecordingCommonTestSuite, unittest.TestCase): - ExtractorClass = Plexon2RecordingExtractor - downloads = ["plexon"] - entities = [ - ("plexon/4chDemoPL2.pl2", {"stream_id": "3"}), - ] - - class PlexonSortingTest(SortingCommonTestSuite, unittest.TestCase): ExtractorClass = PlexonSortingExtractor downloads = ["plexon"] @@ -153,12 +137,6 @@ class PlexonSortingTest(SortingCommonTestSuite, unittest.TestCase): ] -class Plexon2SortingTest(SortingCommonTestSuite, unittest.TestCase): - ExtractorClass = Plexon2SortingExtractor - downloads = ["plexon"] - entities = [("plexon/4chDemoPL2.pl2", {"sampling_frequency": 40000})] - - class NeuralynxRecordingTest(RecordingCommonTestSuite, unittest.TestCase): ExtractorClass = NeuralynxRecordingExtractor downloads = ["neuralynx"] @@ -312,6 +290,32 @@ def test_pickling(self): pass +# We mark plexon2 tests as they require additional 
dependencies (wine) +@pytest.mark.plexon2 +class Plexon2RecordingTest(RecordingCommonTestSuite, unittest.TestCase): + ExtractorClass = Plexon2RecordingExtractor + downloads = ["plexon"] + entities = [ + ("plexon/4chDemoPL2.pl2", {"stream_id": "3"}), + ] + + +@pytest.mark.plexon2 +class Plexon2EventTest(EventCommonTestSuite, unittest.TestCase): + ExtractorClass = Plexon2EventExtractor + downloads = ["plexon"] + entities = [ + ("plexon/4chDemoPL2.pl2"), + ] + + +@pytest.mark.plexon2 +class Plexon2SortingTest(SortingCommonTestSuite, unittest.TestCase): + ExtractorClass = Plexon2SortingExtractor + downloads = ["plexon"] + entities = [("plexon/4chDemoPL2.pl2", {"sampling_frequency": 40000})] + + if __name__ == "__main__": # test = MearecSortingTest() # test = SpikeGLXRecordingTest() From 23f8677f148a819a57c1f859aa327ca40d124f25 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 29 Aug 2023 17:33:15 +0200 Subject: [PATCH 21/35] Install zugbruecke not on win --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index ddb0de4893..b0bf4cdcf5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,7 +68,7 @@ extractors = [ "ONE-api>=1.19.1", "ibllib>=2.21.0", "pymatreader>=0.0.32", # For cell explorer matlab files - "zugbruecke", # For plexon2 + "zugbruecke>=0.2; sys_platform!='win32'", # For plexon2 ] streaming_extractors = [ From 6f902d46dad68d5f69322175537e7a89b4e0bd43 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Wed, 30 Aug 2023 11:11:39 +0200 Subject: [PATCH 22/35] patch --- src/spikeinterface/preprocessing/silence_periods.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py index 223122e927..c2ffcc6843 100644 --- a/src/spikeinterface/preprocessing/silence_periods.py +++ b/src/spikeinterface/preprocessing/silence_periods.py @@ -46,7 +46,7 @@ def __init__(self, recording, list_periods, mode="zeros", **random_chunk_kwargs) num_seg = recording.get_num_segments() if num_seg == 1: - if isinstance(list_periods, (list, np.ndarray)) and not np.isscalar(list_periods[0]): + if isinstance(list_periods, (list, np.ndarray)) and np.array(list_periods).ndim == 2: # when unique segment accept list instead of of list of list/arrays list_periods = [list_periods] From a797aa33c561871b10b4a441985d89546e8ebc2e Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 30 Aug 2023 15:55:16 +0200 Subject: [PATCH 23/35] Improve debug plots and handle_collisions=True by default --- .../postprocessing/amplitude_scalings.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 1f7923eb05..a9b3898388 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -187,7 +187,7 @@ def compute_amplitude_scalings( max_dense_channels=16, ms_before=None, ms_after=None, - handle_collisions=False, + handle_collisions=True, delta_collision_ms=2, load_if_exists=False, outputs="concatenated", @@ -212,7 +212,7 @@ def compute_amplitude_scalings( ms_after : float, default: None The cut out to apply after the spike peak to extract local waveforms. If None, the WaveformExtractor ms_after is used. - handle_collisions: bool, default: False + handle_collisions: bool, default: True Whether to handle collisions between spikes. 
If True, the amplitude scaling of colliding spikes (defined as spikes within `delta_collision_ms` ms and with overlapping sparsity) is computed by fitting a multi-linear regression model (with `sklearn.LinearRegression`). If False, each spike is fitted independently. @@ -598,12 +598,12 @@ def fit_collision( # for i in range(num_collisions): # overlapping_spikes = collisions[collision_keys[i]] -# ax = _plot_one_collision( +# ax = plot_one_collision( # we, collision_keys[i], overlapping_spikes, spikes, scalings=scalings, sparsity=sparsity # ) -# def _plot_one_collision( +# def plot_one_collision( # we, # spike_index, # overlapping_spikes, @@ -611,9 +611,13 @@ def fit_collision( # scalings=None, # sparsity=None, # cut_out_samples=100, +# ax=None # ): # import matplotlib.pyplot as plt +# if ax is None: +# fig, ax = plt.subplots() + # recording = we.recording # nbefore_nafter_max = max(we.nafter, we.nbefore) # cut_out_samples = max(cut_out_samples, nbefore_nafter_max) @@ -644,7 +648,7 @@ def fit_collision( # tr_overlap = recording.get_traces(start_frame=sf, end_frame=ef, channel_ids=channel_ids, return_scaled=True) # ts = np.arange(sf, ef) / recording.sampling_frequency * 1000 # max_tr = np.max(np.abs(tr_overlap)) -# fig, ax = plt.subplots() + # for ch, tr in enumerate(tr_overlap.T): # _ = ax.plot(ts, tr + 1.2 * ch * max_tr, color="k") # ax.text(ts[0], 1.2 * ch * max_tr - 0.3 * max_tr, f"Ch:{channel_ids[ch]}") From 287e8af9621385d4fa835be6356b7695993cdc16 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 4 Sep 2023 16:46:18 +0200 Subject: [PATCH 24/35] Fix tests --- src/spikeinterface/postprocessing/amplitude_scalings.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 0dd2587fba..5a0148c5c4 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -24,6 +24,7 @@ def __init__(self, waveform_extractor): self.spikes = self.waveform_extractor.sorting.to_spike_vector( extremum_channel_inds=extremum_channel_inds, use_cache=False ) + self.collisions = None def _set_params( self, @@ -138,9 +139,11 @@ def _run(self, **job_kwargs): if handle_collisions: for collision in collisions: collisions_dict.update(collision) + self.collisions = collisions_dict + # Note: collisions are note in _extension_data because they are not pickable. 
We only store the indices + self._extension_data["collisions"] = np.array(list(collisions_dict.keys())) - self._extension_data[f"amplitude_scalings"] = amp_scalings - self._extension_data[f"collisions"] = collisions_dict + self._extension_data["amplitude_scalings"] = amp_scalings def get_data(self, outputs="concatenated"): """ From 34538311d1c00f5f45b4cb84a03984efa9ef4f3d Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 7 Sep 2023 18:38:25 +0200 Subject: [PATCH 25/35] fedeback from alessio and ramon --- src/spikeinterface/core/waveform_tools.py | 23 +++++++++-------------- 1 file changed, 9 insertions(+), 14 deletions(-) diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index 252ea68738..53c0df68df 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -258,8 +258,8 @@ def distribute_waveforms_to_buffers( inds_by_unit[unit_id] = inds # and run - func = _worker_ditribute_buffers - init_func = _init_worker_ditribute_buffers + func = _worker_distribute_buffers + init_func = _init_worker_distribute_buffers init_args = ( recording, @@ -283,7 +283,7 @@ def distribute_waveforms_to_buffers( # used by ChunkRecordingExecutor -def _init_worker_ditribute_buffers( +def _init_worker_distribute_buffers( recording, unit_ids, spikes, arrays_info, nbefore, nafter, return_scaled, inds_by_unit, mode, sparsity_mask ): # create a local dict per worker @@ -329,7 +329,7 @@ def _init_worker_ditribute_buffers( # used by ChunkRecordingExecutor -def _worker_ditribute_buffers(segment_index, start_frame, end_frame, worker_ctx): +def _worker_distribute_buffers(segment_index, start_frame, end_frame, worker_ctx): # recover variables of the worker recording = worker_ctx["recording"] unit_ids = worker_ctx["unit_ids"] @@ -480,7 +480,7 @@ def extract_waveforms_to_single_buffer( shape = (num_spikes, nsamples, num_chans) if mode == "memmap": - filename = str(folder / f"all_waveforms.npy") + filename = str(folder / f"waveforms.npy") all_waveforms = np.lib.format.open_memmap(filename, mode="w+", dtype=dtype, shape=shape) wf_array_info = filename elif mode == "shared_memory": @@ -497,15 +497,10 @@ def extract_waveforms_to_single_buffer( job_kwargs = fix_job_kwargs(job_kwargs) - inds_by_unit = {} - for unit_ind, unit_id in enumerate(unit_ids): - (inds,) = np.nonzero(spikes["unit_index"] == unit_ind) - inds_by_unit[unit_id] = inds - if num_spikes > 0: # and run - func = _worker_ditribute_single_buffer - init_func = _init_worker_ditribute_single_buffer + func = _worker_distribute_single_buffer + init_func = _init_worker_distribute_single_buffer init_args = ( recording, @@ -537,7 +532,7 @@ def extract_waveforms_to_single_buffer( return all_waveforms, wf_array_info -def _init_worker_ditribute_single_buffer( +def _init_worker_distribute_single_buffer( recording, unit_ids, spikes, wf_array_info, nbefore, nafter, return_scaled, mode, sparsity_mask ): worker_ctx = {} @@ -576,7 +571,7 @@ def _init_worker_ditribute_single_buffer( # used by ChunkRecordingExecutor -def _worker_ditribute_single_buffer(segment_index, start_frame, end_frame, worker_ctx): +def _worker_distribute_single_buffer(segment_index, start_frame, end_frame, worker_ctx): # recover variables of the worker recording = worker_ctx["recording"] unit_ids = worker_ctx["unit_ids"] From d79dbe26bb1d3f5db45b1ac36d84f7b96be08f18 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 7 Sep 2023 18:51:26 +0200 Subject: [PATCH 26/35] extract_waveforms_to_single_buffer change folder 
to file_path --- .../core/tests/test_waveform_tools.py | 27 ++++++++++++------- src/spikeinterface/core/waveform_tools.py | 17 +++++------- 2 files changed, 24 insertions(+), 20 deletions(-) diff --git a/src/spikeinterface/core/tests/test_waveform_tools.py b/src/spikeinterface/core/tests/test_waveform_tools.py index 52d7472c92..1d7e38832a 100644 --- a/src/spikeinterface/core/tests/test_waveform_tools.py +++ b/src/spikeinterface/core/tests/test_waveform_tools.py @@ -59,14 +59,15 @@ def test_waveform_tools(): ] some_modes = [ {"mode": "memmap"}, + {"mode": "shared_memory"}, ] - if platform.system() != "Windows": - # shared memory on windows is buggy... - some_modes.append( - { - "mode": "shared_memory", - } - ) + # if platform.system() != "Windows": + # # shared memory on windows is buggy... + # some_modes.append( + # { + # "mode": "shared_memory", + # } + # ) some_sparsity = [ dict(sparsity_mask=None), @@ -87,9 +88,11 @@ def test_waveform_tools(): if wf_folder.is_dir(): shutil.rmtree(wf_folder) wf_folder.mkdir(parents=True) - mode_kwargs_ = dict(**mode_kwargs, folder=wf_folder) - else: - mode_kwargs_ = mode_kwargs + wf_file_path = wf_folder / "waveforms_all_units.npy" + + mode_kwargs_ = dict(**mode_kwargs) + if mode_kwargs["mode"] == "memmap": + mode_kwargs_["folder" ] = wf_folder wfs_arrays = extract_waveforms_to_buffers( recording, @@ -113,6 +116,10 @@ def test_waveform_tools(): else: list_wfs_sparse.append(wfs_arrays) + mode_kwargs_ = dict(**mode_kwargs) + if mode_kwargs["mode"] == "memmap": + mode_kwargs_["file_path" ] = wf_file_path + all_waveforms = extract_waveforms_to_single_buffer( recording, spikes, diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index 53c0df68df..c363ac49dc 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -406,7 +406,7 @@ def extract_waveforms_to_single_buffer( nafter, mode="memmap", return_scaled=False, - folder=None, + file_path=None, dtype=None, sparsity_mask=None, copy=False, @@ -442,8 +442,8 @@ def extract_waveforms_to_single_buffer( Mode to use ('memmap' | 'shared_memory') return_scaled: bool Scale traces before exporting to buffer or not. - folder: str or path - In case of memmap mode, folder to save npy files + file_path: str or path + In case of memmap mode, file to save npy file. 
dtype: numpy.dtype dtype for waveforms buffer sparsity_mask: None or array of bool @@ -468,9 +468,9 @@ def extract_waveforms_to_single_buffer( dtype = np.dtype(dtype) if mode == "shared_memory": - assert folder is None + assert file_path is None else: - folder = Path(folder) + file_path = Path(file_path) num_spikes = spikes.size if sparsity_mask is None: @@ -480,9 +480,8 @@ def extract_waveforms_to_single_buffer( shape = (num_spikes, nsamples, num_chans) if mode == "memmap": - filename = str(folder / f"waveforms.npy") - all_waveforms = np.lib.format.open_memmap(filename, mode="w+", dtype=dtype, shape=shape) - wf_array_info = filename + all_waveforms = np.lib.format.open_memmap(file_path, mode="w+", dtype=dtype, shape=shape) + wf_array_info = str(file_path) elif mode == "shared_memory": if num_spikes == 0 or num_chans == 0: all_waveforms = np.zeros(shape, dtype=dtype) @@ -538,7 +537,6 @@ def _init_worker_distribute_single_buffer( worker_ctx = {} worker_ctx["recording"] = recording worker_ctx["wf_array_info"] = wf_array_info - worker_ctx["unit_ids"] = unit_ids worker_ctx["spikes"] = spikes worker_ctx["nbefore"] = nbefore worker_ctx["nafter"] = nafter @@ -574,7 +572,6 @@ def _init_worker_distribute_single_buffer( def _worker_distribute_single_buffer(segment_index, start_frame, end_frame, worker_ctx): # recover variables of the worker recording = worker_ctx["recording"] - unit_ids = worker_ctx["unit_ids"] segment_slices = worker_ctx["segment_slices"] spikes = worker_ctx["spikes"] nbefore = worker_ctx["nbefore"] From 5c6615975704bd9bbbda722291d04ff0ccfb3c90 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 7 Sep 2023 16:51:54 +0000 Subject: [PATCH 27/35] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/tests/test_waveform_tools.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/core/tests/test_waveform_tools.py b/src/spikeinterface/core/tests/test_waveform_tools.py index 1d7e38832a..e9cf1bfb5f 100644 --- a/src/spikeinterface/core/tests/test_waveform_tools.py +++ b/src/spikeinterface/core/tests/test_waveform_tools.py @@ -92,7 +92,7 @@ def test_waveform_tools(): mode_kwargs_ = dict(**mode_kwargs) if mode_kwargs["mode"] == "memmap": - mode_kwargs_["folder" ] = wf_folder + mode_kwargs_["folder"] = wf_folder wfs_arrays = extract_waveforms_to_buffers( recording, @@ -118,8 +118,8 @@ def test_waveform_tools(): mode_kwargs_ = dict(**mode_kwargs) if mode_kwargs["mode"] == "memmap": - mode_kwargs_["file_path" ] = wf_file_path - + mode_kwargs_["file_path"] = wf_file_path + all_waveforms = extract_waveforms_to_single_buffer( recording, spikes, From 7080696f12617bfd08769c89a8471768878191ca Mon Sep 17 00:00:00 2001 From: Garcia Samuel Date: Fri, 8 Sep 2023 10:17:48 +0200 Subject: [PATCH 28/35] Update src/spikeinterface/core/waveform_tools.py Co-authored-by: Alessio Buccino --- src/spikeinterface/core/waveform_tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index c363ac49dc..a63d0a80b7 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -417,7 +417,7 @@ def extract_waveforms_to_single_buffer( Allocate a single buffer (memmap or or shared memory) and then distribute every waveform into it. 
Contrary to extract_waveforms_to_buffers() all waveforms are extracted in the same buffer, so the spike vector is - needed to recover waveforms unit by unit. Importantly in case of sparsity, the channel are not aligned across + needed to recover waveforms unit by unit. Importantly in case of sparsity, the channels are not aligned across units. Important note: for the "shared_memory" mode wf_array_info contains reference to From 99602f17ed4793bbf21b577cd8f87860cc3c3c2b Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 11 Sep 2023 10:46:51 +0200 Subject: [PATCH 29/35] Make plexon2 tests conditional on Wine dependency (on Linux) --- .../extractors/tests/test_neoextractors.py | 40 ++++++++++++++++++- 1 file changed, 38 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/extractors/tests/test_neoextractors.py b/src/spikeinterface/extractors/tests/test_neoextractors.py index da162eccf1..5fe42b0c4e 100644 --- a/src/spikeinterface/extractors/tests/test_neoextractors.py +++ b/src/spikeinterface/extractors/tests/test_neoextractors.py @@ -1,5 +1,6 @@ import unittest -from platform import python_version +import platform +import subprocess from packaging import version import pytest @@ -18,6 +19,38 @@ local_folder = get_global_dataset_folder() / "ephy_testing_data" +def has_plexon2_dependencies(): + """ + Check if required Plexon2 dependencies are installed on different OS. + """ + + os_type = platform.system() + + if os_type == "Windows": + # On Windows, no need for additional dependencies + return True + + elif os_type == "Linux": + # Check for 'wine' using dpkg + try: + result_wine = subprocess.run( + ["dpkg", "-l", "wine"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True + ) + except subprocess.CalledProcessError: + return False + + # Check for 'zugbruecke' using pip + try: + import zugbruecke + + return True + except ImportError: + return False + else: + # Not sure about MacOS + raise ValueError(f"Unsupported OS: {os_type}") + + class MearecRecordingTest(RecordingCommonTestSuite, unittest.TestCase): ExtractorClass = MEArecRecordingExtractor downloads = ["mearec"] @@ -218,7 +251,7 @@ class Spike2RecordingTest(RecordingCommonTestSuite, unittest.TestCase): @pytest.mark.skipif( - version.parse(python_version()) >= version.parse("3.10"), + version.parse(platform.python_version()) >= version.parse("3.10"), reason="Sonpy only testing with Python < 3.10!", ) class CedRecordingTest(RecordingCommonTestSuite, unittest.TestCase): @@ -291,6 +324,7 @@ def test_pickling(self): # We mark plexon2 tests as they require additional dependencies (wine) +@pytest.mark.skipif(not has_plexon2_dependencies(), reason="Required dependencies not installed") @pytest.mark.plexon2 class Plexon2RecordingTest(RecordingCommonTestSuite, unittest.TestCase): ExtractorClass = Plexon2RecordingExtractor @@ -300,6 +334,7 @@ class Plexon2RecordingTest(RecordingCommonTestSuite, unittest.TestCase): ] +@pytest.mark.skipif(not has_plexon2_dependencies(), reason="Required dependencies not installed") @pytest.mark.plexon2 class Plexon2EventTest(EventCommonTestSuite, unittest.TestCase): ExtractorClass = Plexon2EventExtractor @@ -309,6 +344,7 @@ class Plexon2EventTest(EventCommonTestSuite, unittest.TestCase): ] +@pytest.mark.skipif(not has_plexon2_dependencies(), reason="Required dependencies not installed") @pytest.mark.plexon2 class Plexon2SortingTest(SortingCommonTestSuite, unittest.TestCase): ExtractorClass = Plexon2SortingExtractor From d7d34c4c676e49724da0c79aab2c7605865473bd Mon Sep 17 00:00:00 2001 From: 
"pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 11 Sep 2023 20:49:24 +0000 Subject: [PATCH 30/35] [pre-commit.ci] pre-commit autoupdate MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/psf/black: 23.7.0 → 23.9.1](https://github.com/psf/black/compare/23.7.0...23.9.1) --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ced1ee6a2f..07601cd208 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,7 +6,7 @@ repos: - id: end-of-file-fixer - id: trailing-whitespace - repo: https://github.com/psf/black - rev: 23.7.0 + rev: 23.9.1 hooks: - id: black files: ^src/ From e5a523c9263fa1a229e89905639496da03dd39e0 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 12 Sep 2023 10:12:54 +0200 Subject: [PATCH 31/35] Improvement after Ramon comments. --- src/spikeinterface/core/waveform_tools.py | 67 +++++++++++------------ 1 file changed, 31 insertions(+), 36 deletions(-) diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index a63d0a80b7..6e0d6f412b 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -350,16 +350,10 @@ def _worker_distribute_buffers(segment_index, start_frame, end_frame, worker_ctx # take only spikes in range [start_frame, end_frame] # this is a slice so no copy!! - i0 = np.searchsorted(in_seg_spikes["sample_index"], start_frame) - i1 = np.searchsorted(in_seg_spikes["sample_index"], end_frame) - if i0 != i1: - # protect from spikes on border : spike_time<0 or spike_time>seg_size - # useful only when max_spikes_per_unit is not None - # waveform will not be extracted and a zeros will be left in the memmap file - while (in_seg_spikes[i0]["sample_index"] - nbefore) < 0 and (i0 != i1): - i0 = i0 + 1 - while (in_seg_spikes[i1 - 1]["sample_index"] + nafter) > seg_size and (i0 != i1): - i1 = i1 - 1 + # the border of segment are protected by nbefore on left an nafter on the right + i0 = np.searchsorted(in_seg_spikes["sample_index"], max(start_frame, nbefore)) + i1 = np.searchsorted(in_seg_spikes["sample_index"], min(end_frame, seg_size - nafter)) + # slice in absolut in spikes vector l0 = i0 + s0 @@ -420,6 +414,9 @@ def extract_waveforms_to_single_buffer( needed to recover waveforms unit by unit. Importantly in case of sparsity, the channels are not aligned across units. + Note: spikes near borders (nbefore/nafter) are not extracted and 0 are put the output buffer. + This ensures that spikes.shape[0] == all_waveforms.shape[0]. + Important note: for the "shared_memory" mode wf_array_info contains reference to the shared memmory buffer, this variable must be reference as long as arrays as used. And this variable is also returned. @@ -449,10 +446,13 @@ def extract_waveforms_to_single_buffer( sparsity_mask: None or array of bool If not None shape must be must be (len(unit_ids), len(channel_ids)) copy: bool - If True (default), the output shared memory object is copied to a numpy standard array. - If copy=False then arrays_info is also return. Please keep in mind that arrays_info - need to be referenced as long as waveforms_by_units will be used otherwise it will be very hard to debug. 
- Also when copy=False the SharedMemory will need to be unlink manually + If True (default), the output shared memory object is copied to a numpy standard array and no reference + to the internal shared memory object is kept. + If copy=False then the shared memory object is also returned. Please keep in mind that the shared memory object + need to be referenced as long as all_waveforms will be used otherwise it might produce segmentation + faults which are hard to debug. + Also when copy=False the SharedMemory will need to be unlink manually if proper cleanup of resources is desired. + {} Returns @@ -481,7 +481,8 @@ def extract_waveforms_to_single_buffer( if mode == "memmap": all_waveforms = np.lib.format.open_memmap(file_path, mode="w+", dtype=dtype, shape=shape) - wf_array_info = str(file_path) + # wf_array_info = str(file_path) + wf_array_info = dict(filename=str(file_path)) elif mode == "shared_memory": if num_spikes == 0 or num_chans == 0: all_waveforms = np.zeros(shape, dtype=dtype) @@ -490,7 +491,8 @@ def extract_waveforms_to_single_buffer( else: all_waveforms, shm = make_shared_array(shape, dtype) shm_name = shm.name - wf_array_info = (shm, shm_name, dtype.str, shape) + # wf_array_info = (shm, shm_name, dtype.str, shape) + wf_array_info = dict(shm=shm, shm_name=shm_name, dtype=dtype.str, shape=shape) else: raise ValueError("allocate_waveforms_buffers bad mode") @@ -503,7 +505,6 @@ def extract_waveforms_to_single_buffer( init_args = ( recording, - unit_ids, spikes, wf_array_info, nbefore, @@ -532,7 +533,7 @@ def extract_waveforms_to_single_buffer( def _init_worker_distribute_single_buffer( - recording, unit_ids, spikes, wf_array_info, nbefore, nafter, return_scaled, mode, sparsity_mask + recording, spikes, wf_array_info, nbefore, nafter, return_scaled, mode, sparsity_mask ): worker_ctx = {} worker_ctx["recording"] = recording @@ -545,13 +546,13 @@ def _init_worker_distribute_single_buffer( worker_ctx["mode"] = mode if mode == "memmap": - filename = wf_array_info + filename = wf_array_info["filename"] all_waveforms = np.load(str(filename), mmap_mode="r+") worker_ctx["all_waveforms"] = all_waveforms elif mode == "shared_memory": from multiprocessing.shared_memory import SharedMemory - shm, shm_name, dtype, shape = wf_array_info + shm_name, dtype, shape = wf_array_info["shm_name"], wf_array_info["dtype"], wf_array_info["shape"] shm = SharedMemory(shm_name) all_waveforms = np.ndarray(shape=shape, dtype=dtype, buffer=shm.buf) worker_ctx["shm"] = shm @@ -587,16 +588,10 @@ def _worker_distribute_single_buffer(segment_index, start_frame, end_frame, work # take only spikes in range [start_frame, end_frame] # this is a slice so no copy!! 
- i0 = np.searchsorted(in_seg_spikes["sample_index"], start_frame) - i1 = np.searchsorted(in_seg_spikes["sample_index"], end_frame) - if i0 != i1: - # protect from spikes on border : spike_time<0 or spike_time>seg_size - # useful only when max_spikes_per_unit is not None - # waveform will not be extracted and a zeros will be left in the memmap file - while (in_seg_spikes[i0]["sample_index"] - nbefore) < 0 and (i0 != i1): - i0 = i0 + 1 - while (in_seg_spikes[i1 - 1]["sample_index"] + nafter) > seg_size and (i0 != i1): - i1 = i1 - 1 + # the border of segment are protected by nbefore on left an nafter on the right + i0 = np.searchsorted(in_seg_spikes["sample_index"], max(start_frame, nbefore)) + i1 = np.searchsorted(in_seg_spikes["sample_index"], min(end_frame, seg_size - nafter)) + # slice in absolut in spikes vector l0 = i0 + s0 @@ -611,17 +606,17 @@ def _worker_distribute_single_buffer(segment_index, start_frame, end_frame, work start_frame=start, end_frame=end, segment_index=segment_index, return_scaled=return_scaled ) - for spike_ind in range(l0, l1): - sample_index = spikes[spike_ind]["sample_index"] - unit_index = spikes[spike_ind]["unit_index"] + for spike_index in range(l0, l1): + sample_index = spikes[spike_index]["sample_index"] + unit_index = spikes[spike_index]["unit_index"] wf = traces[sample_index - start - nbefore : sample_index - start + nafter, :] if sparsity_mask is None: - all_waveforms[spike_ind, :, :] = wf + all_waveforms[spike_index, :, :] = wf else: mask = sparsity_mask[unit_index, :] wf = wf[:, mask] - all_waveforms[spike_ind, :, : wf.shape[1]] = wf + all_waveforms[spike_index, :, : wf.shape[1]] = wf if worker_ctx["mode"] == "memmap": all_waveforms.flush() @@ -642,7 +637,7 @@ def split_waveforms_by_units(unit_ids, spikes, all_waveforms, sparsity_mask=None sparsity_mask : None or numpy array Optionally the boolean sparsity mask folder : None or str or Path - If a folde ri sgiven all + If a folder is given all waveforms by units are copied in a npy file using f"waveforms_{unit_id}.npy" naming. Returns ------- From e37b0515742b129984eb75da35c869a1de6b78d3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 12 Sep 2023 08:13:32 +0000 Subject: [PATCH 32/35] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/waveform_tools.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index 6e0d6f412b..39623da329 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -354,7 +354,6 @@ def _worker_distribute_buffers(segment_index, start_frame, end_frame, worker_ctx i0 = np.searchsorted(in_seg_spikes["sample_index"], max(start_frame, nbefore)) i1 = np.searchsorted(in_seg_spikes["sample_index"], min(end_frame, seg_size - nafter)) - # slice in absolut in spikes vector l0 = i0 + s0 l1 = i1 + s0 @@ -416,7 +415,7 @@ def extract_waveforms_to_single_buffer( Note: spikes near borders (nbefore/nafter) are not extracted and 0 are put the output buffer. This ensures that spikes.shape[0] == all_waveforms.shape[0]. - + Important note: for the "shared_memory" mode wf_array_info contains reference to the shared memmory buffer, this variable must be reference as long as arrays as used. And this variable is also returned. 
@@ -481,7 +480,7 @@ def extract_waveforms_to_single_buffer( if mode == "memmap": all_waveforms = np.lib.format.open_memmap(file_path, mode="w+", dtype=dtype, shape=shape) - # wf_array_info = str(file_path) + # wf_array_info = str(file_path) wf_array_info = dict(filename=str(file_path)) elif mode == "shared_memory": if num_spikes == 0 or num_chans == 0: @@ -592,7 +591,6 @@ def _worker_distribute_single_buffer(segment_index, start_frame, end_frame, work i0 = np.searchsorted(in_seg_spikes["sample_index"], max(start_frame, nbefore)) i1 = np.searchsorted(in_seg_spikes["sample_index"], min(end_frame, seg_size - nafter)) - # slice in absolut in spikes vector l0 = i0 + s0 l1 = i1 + s0 From 966d56a9fa150472e43f888ad4e6b62f89a77ef6 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 13 Sep 2023 08:18:51 +0200 Subject: [PATCH 33/35] doc --- src/spikeinterface/core/waveform_tools.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index 39623da329..da8e3d64b6 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -417,9 +417,11 @@ def extract_waveforms_to_single_buffer( This ensures that spikes.shape[0] == all_waveforms.shape[0]. Important note: for the "shared_memory" mode wf_array_info contains reference to - the shared memmory buffer, this variable must be reference as long as arrays as used. - And this variable is also returned. - To avoid this a copy to non shared memmory can be perform at the end. + the shared memmory buffer, this variable must be referenced as long as arrays is used. + This variable must also unlink() when the array is de-referenced. + To avoid this complicated behavior, by default (copy=True) the shared memmory buffer is copied into a standard + numpy array. 
+ Parameters ---------- From e73cf7e107026d80b176e9fb420b31cea964b730 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 13 Sep 2023 10:27:39 +0200 Subject: [PATCH 34/35] Simplify plexon2 tests (only run when dependencies are installed) --- .github/workflows/full-test-with-codecov.yml | 2 +- .github/workflows/full-test.yml | 5 +---- src/spikeinterface/extractors/tests/test_neoextractors.py | 5 +---- 3 files changed, 3 insertions(+), 9 deletions(-) diff --git a/.github/workflows/full-test-with-codecov.yml b/.github/workflows/full-test-with-codecov.yml index d0bf109a00..a5561c2ffc 100644 --- a/.github/workflows/full-test-with-codecov.yml +++ b/.github/workflows/full-test-with-codecov.yml @@ -54,7 +54,7 @@ jobs: - name: run tests run: | source ${{ github.workspace }}/test_env/bin/activate - pytest -m "not sorters_external and not plexon2" --cov=./ --cov-report xml:./coverage.xml -vv -ra --durations=0 | tee report_full.txt; test ${PIPESTATUS[0]} -eq 0 || exit 1 + pytest -m "not sorters_external" --cov=./ --cov-report xml:./coverage.xml -vv -ra --durations=0 | tee report_full.txt; test ${PIPESTATUS[0]} -eq 0 || exit 1 echo "# Timing profile of full tests" >> $GITHUB_STEP_SUMMARY python ./.github/build_job_summary.py report_full.txt >> $GITHUB_STEP_SUMMARY cat $GITHUB_STEP_SUMMARY diff --git a/.github/workflows/full-test.yml b/.github/workflows/full-test.yml index a343500c08..8f88e84039 100644 --- a/.github/workflows/full-test.yml +++ b/.github/workflows/full-test.yml @@ -133,7 +133,7 @@ jobs: run: ./.github/run_tests.sh core - name: Test extractors if: ${{ steps.modules-changed.outputs.EXTRACTORS_CHANGED == 'true' || steps.modules-changed.outputs.CORE_CHANGED == 'true' }} - run: ./.github/run_tests.sh "extractors and not streaming_extractors and not plexon2" + run: ./.github/run_tests.sh "extractors and not streaming_extractors" - name: Test preprocessing if: ${{ steps.modules-changed.outputs.PREPROCESSING_CHANGED == 'true' || steps.modules-changed.outputs.CORE_CHANGED == 'true' }} run: ./.github/run_tests.sh preprocessing @@ -164,6 +164,3 @@ jobs: - name: Test internal sorters if: ${{ steps.modules-changed.outputs.SORTERS_INTERNAL_CHANGED == 'true' || steps.modules-changed.outputs.SORTINGCOMPONENTS_CHANGED || steps.modules-changed.outputs.CORE_CHANGED == 'true' }} run: ./.github/run_tests.sh sorters_internal - - name: Test plexon2 - if: ${{ steps.modules-changed.outputs.PLEXON2_CHANGED == 'true' }} - run: ./.github/run_tests.sh plexon2 diff --git a/src/spikeinterface/extractors/tests/test_neoextractors.py b/src/spikeinterface/extractors/tests/test_neoextractors.py index 5fe42b0c4e..ce2703d382 100644 --- a/src/spikeinterface/extractors/tests/test_neoextractors.py +++ b/src/spikeinterface/extractors/tests/test_neoextractors.py @@ -323,9 +323,8 @@ def test_pickling(self): pass -# We mark plexon2 tests as they require additional dependencies (wine) +# We run plexon2 tests only if we have dependencies (wine) @pytest.mark.skipif(not has_plexon2_dependencies(), reason="Required dependencies not installed") -@pytest.mark.plexon2 class Plexon2RecordingTest(RecordingCommonTestSuite, unittest.TestCase): ExtractorClass = Plexon2RecordingExtractor downloads = ["plexon"] @@ -335,7 +334,6 @@ class Plexon2RecordingTest(RecordingCommonTestSuite, unittest.TestCase): @pytest.mark.skipif(not has_plexon2_dependencies(), reason="Required dependencies not installed") -@pytest.mark.plexon2 class Plexon2EventTest(EventCommonTestSuite, unittest.TestCase): ExtractorClass = Plexon2EventExtractor downloads = 
["plexon"] @@ -345,7 +343,6 @@ class Plexon2EventTest(EventCommonTestSuite, unittest.TestCase): @pytest.mark.skipif(not has_plexon2_dependencies(), reason="Required dependencies not installed") -@pytest.mark.plexon2 class Plexon2SortingTest(SortingCommonTestSuite, unittest.TestCase): ExtractorClass = Plexon2SortingExtractor downloads = ["plexon"] From e3b96d3f4aa8cfa63b4703726950e81d86aa43df Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 13 Sep 2023 10:52:19 +0200 Subject: [PATCH 35/35] Relax check_borders (and fix typo) in InjectTemplatesRecording --- src/spikeinterface/core/generate.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index bbf77682ee..d78a2e4e57 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1,5 +1,5 @@ import math - +import warnings import numpy as np from typing import Union, Optional, List, Literal @@ -1037,13 +1037,14 @@ def __init__( parent_recording: Union[BaseRecording, None] = None, num_samples: Optional[List[int]] = None, upsample_vector: Union[List[int], None] = None, - check_borbers: bool = True, + check_borders: bool = False, ) -> None: templates = np.asarray(templates) - if check_borbers: + # TODO: this should be external to this class. It is not the responsability of this class to check the templates + if check_borders: self._check_templates(templates) - # lets test this only once so force check_borbers=false for kwargs - check_borbers = False + # lets test this only once so force check_borders=False for kwargs + check_borders = False self.templates = templates channel_ids = parent_recording.channel_ids if parent_recording is not None else list(range(templates.shape[2])) @@ -1131,7 +1132,7 @@ def __init__( "nbefore": nbefore, "amplitude_factor": amplitude_factor, "upsample_vector": upsample_vector, - "check_borbers": check_borbers, + "check_borders": check_borders, } if parent_recording is None: self._kwargs["num_samples"] = num_samples @@ -1144,8 +1145,8 @@ def _check_templates(templates: np.ndarray): threshold = 0.01 * max_value if max(np.max(np.abs(templates[:, 0])), np.max(np.abs(templates[:, -1]))) > threshold: - raise Exception( - "Warning!\nYour templates do not go to 0 on the edges in InjectTemplatesRecording.__init__\nPlease make your window bigger." + warnings.warn( + "Warning! Your templates do not go to 0 on the edges in InjectTemplatesRecording. Please make your window bigger." )