diff --git a/.github/actions/install-wine/action.yml b/.github/actions/install-wine/action.yml new file mode 100644 index 0000000000..3ae08ecd34 --- /dev/null +++ b/.github/actions/install-wine/action.yml @@ -0,0 +1,21 @@ +name: Install packages +description: This action installs the package and its dependencies for testing + +inputs: + python-version: + description: 'Python version to set up' + required: false + os: + description: 'Operating system to set up' + required: false + +runs: + using: "composite" + steps: + - name: Install wine (needed for Plexon2) + run: | + sudo rm -f /etc/apt/sources.list.d/microsoft-prod.list + sudo dpkg --add-architecture i386 + sudo apt-get update -qq + sudo apt-get install -yqq --allow-downgrades libc6:i386 libgcc-s1:i386 libstdc++6:i386 wine + shell: bash diff --git a/.github/workflows/full-test.yml b/.github/workflows/full-test.yml index ac5130bade..8f88e84039 100644 --- a/.github/workflows/full-test.yml +++ b/.github/workflows/full-test.yml @@ -75,6 +75,10 @@ jobs: echo "Extractors changed" echo "EXTRACTORS_CHANGED=true" >> $GITHUB_OUTPUT fi + if [[ $file == *"plexon2"* ]]; then + echo "Plexon2 changed" + echo "PLEXON2_CHANGED=true" >> $GITHUB_OUTPUT + fi if [[ $file == *"/preprocessing/"* ]]; then echo "Preprocessing changed" echo "PREPROCESSING_CHANGED=true" >> $GITHUB_OUTPUT @@ -122,6 +126,9 @@ jobs: done - name: Set execute permissions on run_tests.sh run: chmod +x .github/run_tests.sh + - name: Install Wine (Plexon2) + if: ${{ steps.modules-changed.outputs.PLEXON2_CHANGED == 'true' }} + uses: ./.github/actions/install-wine - name: Test core run: ./.github/run_tests.sh core - name: Test extractors diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ced1ee6a2f..07601cd208 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,7 +6,7 @@ repos: - id: end-of-file-fixer - id: trailing-whitespace - repo: https://github.com/psf/black - rev: 23.7.0 + rev: 23.9.1 hooks: - id: black files: ^src/ diff --git a/pyproject.toml b/pyproject.toml index e17d6f6506..474cdc483f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,6 +68,7 @@ extractors = [ "ONE-api>=1.19.1", "ibllib>=2.21.0", "pymatreader>=0.0.32", # For cell explorer matlab files + "zugbruecke>=0.2; sys_platform!='win32'", # For plexon2 ] streaming_extractors = [ diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index bbf77682ee..d78a2e4e57 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1,5 +1,5 @@ import math - +import warnings import numpy as np from typing import Union, Optional, List, Literal @@ -1037,13 +1037,14 @@ def __init__( parent_recording: Union[BaseRecording, None] = None, num_samples: Optional[List[int]] = None, upsample_vector: Union[List[int], None] = None, - check_borbers: bool = True, + check_borders: bool = False, ) -> None: templates = np.asarray(templates) - if check_borbers: + # TODO: this should be external to this class. It is not the responsability of this class to check the templates + if check_borders: self._check_templates(templates) - # lets test this only once so force check_borbers=false for kwargs - check_borbers = False + # lets test this only once so force check_borders=False for kwargs + check_borders = False self.templates = templates channel_ids = parent_recording.channel_ids if parent_recording is not None else list(range(templates.shape[2])) @@ -1131,7 +1132,7 @@ def __init__( "nbefore": nbefore, "amplitude_factor": amplitude_factor, "upsample_vector": upsample_vector, - "check_borbers": check_borbers, + "check_borders": check_borders, } if parent_recording is None: self._kwargs["num_samples"] = num_samples @@ -1144,8 +1145,8 @@ def _check_templates(templates: np.ndarray): threshold = 0.01 * max_value if max(np.max(np.abs(templates[:, 0])), np.max(np.abs(templates[:, -1]))) > threshold: - raise Exception( - "Warning!\nYour templates do not go to 0 on the edges in InjectTemplatesRecording.__init__\nPlease make your window bigger." + warnings.warn( + "Warning! Your templates do not go to 0 on the edges in InjectTemplatesRecording. Please make your window bigger." ) diff --git a/src/spikeinterface/core/tests/test_waveform_tools.py b/src/spikeinterface/core/tests/test_waveform_tools.py index 8da47b1940..e9cf1bfb5f 100644 --- a/src/spikeinterface/core/tests/test_waveform_tools.py +++ b/src/spikeinterface/core/tests/test_waveform_tools.py @@ -8,7 +8,9 @@ from spikeinterface.core import generate_recording, generate_sorting from spikeinterface.core.waveform_tools import ( extract_waveforms_to_buffers, -) # allocate_waveforms_buffers, distribute_waveforms_to_buffers + extract_waveforms_to_single_buffer, + split_waveforms_by_units, +) if hasattr(pytest, "global_test_folder"): @@ -52,96 +54,95 @@ def test_waveform_tools(): unit_ids = sorting.unit_ids some_job_kwargs = [ - {}, {"n_jobs": 1, "chunk_size": 3000, "progress_bar": True}, {"n_jobs": 2, "chunk_size": 3000, "progress_bar": True}, ] + some_modes = [ + {"mode": "memmap"}, + {"mode": "shared_memory"}, + ] + # if platform.system() != "Windows": + # # shared memory on windows is buggy... + # some_modes.append( + # { + # "mode": "shared_memory", + # } + # ) + + some_sparsity = [ + dict(sparsity_mask=None), + dict(sparsity_mask=np.random.randint(0, 2, size=(unit_ids.size, recording.channel_ids.size), dtype="bool")), + ] # memmap mode - list_wfs = [] + list_wfs_dense = [] + list_wfs_sparse = [] for j, job_kwargs in enumerate(some_job_kwargs): - wf_folder = cache_folder / f"test_waveform_tools_{j}" - if wf_folder.is_dir(): - shutil.rmtree(wf_folder) - wf_folder.mkdir(parents=True) - # wfs_arrays, wfs_arrays_info = allocate_waveforms_buffers(recording, spikes, unit_ids, nbefore, nafter, mode='memmap', folder=wf_folder, dtype=dtype) - # distribute_waveforms_to_buffers(recording, spikes, unit_ids, wfs_arrays_info, nbefore, nafter, return_scaled, **job_kwargs) - wfs_arrays = extract_waveforms_to_buffers( - recording, - spikes, - unit_ids, - nbefore, - nafter, - mode="memmap", - return_scaled=False, - folder=wf_folder, - dtype=dtype, - sparsity_mask=None, - copy=False, - **job_kwargs, - ) - for unit_ind, unit_id in enumerate(unit_ids): - wf = wfs_arrays[unit_id] - assert wf.shape[0] == np.sum(spikes["unit_index"] == unit_ind) - list_wfs.append({unit_id: wfs_arrays[unit_id].copy() for unit_id in unit_ids}) - _check_all_wf_equal(list_wfs) - - # memory - if platform.system() != "Windows": - # shared memory on windows is buggy... - list_wfs = [] - for job_kwargs in some_job_kwargs: - # wfs_arrays, wfs_arrays_info = allocate_waveforms_buffers(recording, spikes, unit_ids, nbefore, nafter, mode='shared_memory', folder=None, dtype=dtype) - # distribute_waveforms_to_buffers(recording, spikes, unit_ids, wfs_arrays_info, nbefore, nafter, return_scaled, mode='shared_memory', **job_kwargs) - wfs_arrays = extract_waveforms_to_buffers( - recording, - spikes, - unit_ids, - nbefore, - nafter, - mode="shared_memory", - return_scaled=False, - folder=None, - dtype=dtype, - sparsity_mask=None, - copy=True, - **job_kwargs, - ) - for unit_ind, unit_id in enumerate(unit_ids): - wf = wfs_arrays[unit_id] - assert wf.shape[0] == np.sum(spikes["unit_index"] == unit_ind) - list_wfs.append({unit_id: wfs_arrays[unit_id].copy() for unit_id in unit_ids}) - # to avoid warning we need to first destroy arrays then sharedmemm object - # del wfs_arrays - # del wfs_arrays_info - _check_all_wf_equal(list_wfs) - - # with sparsity - wf_folder = cache_folder / "test_waveform_tools_sparse" - if wf_folder.is_dir(): - shutil.rmtree(wf_folder) - wf_folder.mkdir() - - sparsity_mask = np.random.randint(0, 2, size=(unit_ids.size, recording.channel_ids.size), dtype="bool") - job_kwargs = {"n_jobs": 1, "chunk_size": 3000, "progress_bar": True} - - # wfs_arrays, wfs_arrays_info = allocate_waveforms_buffers(recording, spikes, unit_ids, nbefore, nafter, mode='memmap', folder=wf_folder, dtype=dtype, sparsity_mask=sparsity_mask) - # distribute_waveforms_to_buffers(recording, spikes, unit_ids, wfs_arrays_info, nbefore, nafter, return_scaled, sparsity_mask=sparsity_mask, **job_kwargs) - - wfs_arrays = extract_waveforms_to_buffers( - recording, - spikes, - unit_ids, - nbefore, - nafter, - mode="memmap", - return_scaled=False, - folder=wf_folder, - dtype=dtype, - sparsity_mask=sparsity_mask, - copy=False, - **job_kwargs, - ) + for k, mode_kwargs in enumerate(some_modes): + for l, sparsity_kwargs in enumerate(some_sparsity): + # print() + # print(job_kwargs, mode_kwargs, 'sparse=', sparsity_kwargs['sparsity_mask'] is None) + + if mode_kwargs["mode"] == "memmap": + wf_folder = cache_folder / f"test_waveform_tools_{j}_{k}_{l}" + if wf_folder.is_dir(): + shutil.rmtree(wf_folder) + wf_folder.mkdir(parents=True) + wf_file_path = wf_folder / "waveforms_all_units.npy" + + mode_kwargs_ = dict(**mode_kwargs) + if mode_kwargs["mode"] == "memmap": + mode_kwargs_["folder"] = wf_folder + + wfs_arrays = extract_waveforms_to_buffers( + recording, + spikes, + unit_ids, + nbefore, + nafter, + return_scaled=False, + dtype=dtype, + copy=True, + **sparsity_kwargs, + **mode_kwargs_, + **job_kwargs, + ) + for unit_ind, unit_id in enumerate(unit_ids): + wf = wfs_arrays[unit_id] + assert wf.shape[0] == np.sum(spikes["unit_index"] == unit_ind) + + if sparsity_kwargs["sparsity_mask"] is None: + list_wfs_dense.append(wfs_arrays) + else: + list_wfs_sparse.append(wfs_arrays) + + mode_kwargs_ = dict(**mode_kwargs) + if mode_kwargs["mode"] == "memmap": + mode_kwargs_["file_path"] = wf_file_path + + all_waveforms = extract_waveforms_to_single_buffer( + recording, + spikes, + unit_ids, + nbefore, + nafter, + return_scaled=False, + dtype=dtype, + copy=True, + **sparsity_kwargs, + **mode_kwargs_, + **job_kwargs, + ) + wfs_arrays = split_waveforms_by_units( + unit_ids, spikes, all_waveforms, sparsity_mask=sparsity_kwargs["sparsity_mask"] + ) + if sparsity_kwargs["sparsity_mask"] is None: + list_wfs_dense.append(wfs_arrays) + else: + list_wfs_sparse.append(wfs_arrays) + + _check_all_wf_equal(list_wfs_dense) + _check_all_wf_equal(list_wfs_sparse) if __name__ == "__main__": diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index a10c209f47..da8e3d64b6 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -36,7 +36,7 @@ def extract_waveforms_to_buffers( Same as calling allocate_waveforms_buffers() and then distribute_waveforms_to_buffers(). - Important note: for the "shared_memory" mode wfs_arrays_info contains reference to + Important note: for the "shared_memory" mode arrays_info contains reference to the shared memmory buffer, this variable must be reference as long as arrays as used. And this variable is also returned. To avoid this a copy to non shared memmory can be perform at the end. @@ -66,17 +66,17 @@ def extract_waveforms_to_buffers( If not None shape must be must be (len(unit_ids), len(channel_ids)) copy: bool If True (default), the output shared memory object is copied to a numpy standard array. - If copy=False then wfs_arrays_info is also return. Please keep in mind that wfs_arrays_info - need to be referenced as long as wfs_arrays will be used otherwise it will be very hard to debug. + If copy=False then arrays_info is also return. Please keep in mind that arrays_info + need to be referenced as long as waveforms_by_units will be used otherwise it will be very hard to debug. Also when copy=False the SharedMemory will need to be unlink manually {} Returns ------- - wfs_arrays: dict of arrays + waveforms_by_units: dict of arrays Arrays for all units (memmap or shared_memmep) - wfs_arrays_info: dict of info + arrays_info: dict of info Optionally return in case of shared_memory if copy=False. Dictionary to "construct" array in workers process (memmap file or sharemem info) """ @@ -89,7 +89,7 @@ def extract_waveforms_to_buffers( dtype = "float32" dtype = np.dtype(dtype) - wfs_arrays, wfs_arrays_info = allocate_waveforms_buffers( + waveforms_by_units, arrays_info = allocate_waveforms_buffers( recording, spikes, unit_ids, nbefore, nafter, mode=mode, folder=folder, dtype=dtype, sparsity_mask=sparsity_mask ) @@ -97,7 +97,7 @@ def extract_waveforms_to_buffers( recording, spikes, unit_ids, - wfs_arrays_info, + arrays_info, nbefore, nafter, return_scaled, @@ -107,19 +107,19 @@ def extract_waveforms_to_buffers( ) if mode == "memmap": - return wfs_arrays + return waveforms_by_units elif mode == "shared_memory": if copy: - wfs_arrays = {unit_id: arr.copy() for unit_id, arr in wfs_arrays.items()} + waveforms_by_units = {unit_id: arr.copy() for unit_id, arr in waveforms_by_units.items()} # release all sharedmem buffer for unit_id in unit_ids: - shm = wfs_arrays_info[unit_id][0] + shm = arrays_info[unit_id][0] if shm is not None: # empty array have None shm.unlink() - return wfs_arrays + return waveforms_by_units else: - return wfs_arrays, wfs_arrays_info + return waveforms_by_units, arrays_info extract_waveforms_to_buffers.__doc__ = extract_waveforms_to_buffers.__doc__.format(_shared_job_kwargs_doc) @@ -131,7 +131,7 @@ def allocate_waveforms_buffers( """ Allocate memmap or shared memory buffers before snippet extraction. - Important note: for the shared memory mode wfs_arrays_info contains reference to + Important note: for the shared memory mode arrays_info contains reference to the shared memmory buffer, this variable must be reference as long as arrays as used. Parameters @@ -158,9 +158,9 @@ def allocate_waveforms_buffers( Returns ------- - wfs_arrays: dict of arrays + waveforms_by_units: dict of arrays Arrays for all units (memmap or shared_memmep - wfs_arrays_info: dict of info + arrays_info: dict of info Dictionary to "construct" array in workers process (memmap file or sharemem) """ @@ -173,8 +173,8 @@ def allocate_waveforms_buffers( folder = Path(folder) # prepare buffers - wfs_arrays = {} - wfs_arrays_info = {} + waveforms_by_units = {} + arrays_info = {} for unit_ind, unit_id in enumerate(unit_ids): n_spikes = np.sum(spikes["unit_index"] == unit_ind) if sparsity_mask is None: @@ -186,34 +186,35 @@ def allocate_waveforms_buffers( if mode == "memmap": filename = str(folder / f"waveforms_{unit_id}.npy") arr = np.lib.format.open_memmap(filename, mode="w+", dtype=dtype, shape=shape) - wfs_arrays[unit_id] = arr - wfs_arrays_info[unit_id] = filename + waveforms_by_units[unit_id] = arr + arrays_info[unit_id] = filename elif mode == "shared_memory": - if n_spikes == 0: + if n_spikes == 0 or num_chans == 0: arr = np.zeros(shape, dtype=dtype) shm = None shm_name = None else: arr, shm = make_shared_array(shape, dtype) shm_name = shm.name - wfs_arrays[unit_id] = arr - wfs_arrays_info[unit_id] = (shm, shm_name, dtype.str, shape) + waveforms_by_units[unit_id] = arr + arrays_info[unit_id] = (shm, shm_name, dtype.str, shape) else: raise ValueError("allocate_waveforms_buffers bad mode") - return wfs_arrays, wfs_arrays_info + return waveforms_by_units, arrays_info def distribute_waveforms_to_buffers( recording, spikes, unit_ids, - wfs_arrays_info, + arrays_info, nbefore, nafter, return_scaled, mode="memmap", sparsity_mask=None, + job_name=None, **job_kwargs, ): """ @@ -221,7 +222,7 @@ def distribute_waveforms_to_buffers( Buffers must be pre-allocated with the `allocate_waveforms_buffers()` function. - Important note, for "shared_memory" mode wfs_arrays_info contain reference to + Important note, for "shared_memory" mode arrays_info contain reference to the shared memmory buffer, this variable must be reference as long as arrays as used. Parameters @@ -233,7 +234,7 @@ def distribute_waveforms_to_buffers( This vector can be spikes = Sorting.to_spike_vector() unit_ids: list ot numpy List of unit_ids - wfs_arrays_info: dict + arrays_info: dict Dictionary to "construct" array in workers process (memmap file or sharemem) nbefore: int N samples before spike @@ -257,14 +258,14 @@ def distribute_waveforms_to_buffers( inds_by_unit[unit_id] = inds # and run - func = _waveform_extractor_chunk - init_func = _init_worker_waveform_extractor + func = _worker_distribute_buffers + init_func = _init_worker_distribute_buffers init_args = ( recording, unit_ids, spikes, - wfs_arrays_info, + arrays_info, nbefore, nafter, return_scaled, @@ -272,9 +273,9 @@ def distribute_waveforms_to_buffers( mode, sparsity_mask, ) - processor = ChunkRecordingExecutor( - recording, func, init_func, init_args, job_name=f"extract waveforms {mode}", **job_kwargs - ) + if job_name is None: + job_name = f"extract waveforms {mode} multi buffer" + processor = ChunkRecordingExecutor(recording, func, init_func, init_args, job_name=job_name, **job_kwargs) processor.run() @@ -282,8 +283,8 @@ def distribute_waveforms_to_buffers( # used by ChunkRecordingExecutor -def _init_worker_waveform_extractor( - recording, unit_ids, spikes, wfs_arrays_info, nbefore, nafter, return_scaled, inds_by_unit, mode, sparsity_mask +def _init_worker_distribute_buffers( + recording, unit_ids, spikes, arrays_info, nbefore, nafter, return_scaled, inds_by_unit, mode, sparsity_mask ): # create a local dict per worker worker_ctx = {} @@ -296,23 +297,23 @@ def _init_worker_waveform_extractor( if mode == "memmap": # in memmap mode we have the "too many open file" problem with linux # memmap file will be open on demand and not globally per worker - worker_ctx["wfs_arrays_info"] = wfs_arrays_info + worker_ctx["arrays_info"] = arrays_info elif mode == "shared_memory": from multiprocessing.shared_memory import SharedMemory - wfs_arrays = {} + waveforms_by_units = {} shms = {} - for unit_id, (shm, shm_name, dtype, shape) in wfs_arrays_info.items(): + for unit_id, (shm, shm_name, dtype, shape) in arrays_info.items(): if shm_name is None: arr = np.zeros(shape=shape, dtype=dtype) else: shm = SharedMemory(shm_name) arr = np.ndarray(shape=shape, dtype=dtype, buffer=shm.buf) - wfs_arrays[unit_id] = arr + waveforms_by_units[unit_id] = arr # we need a reference to all sham otherwise we get segment fault!!! shms[unit_id] = shm worker_ctx["shms"] = shms - worker_ctx["wfs_arrays"] = wfs_arrays + worker_ctx["waveforms_by_units"] = waveforms_by_units worker_ctx["unit_ids"] = unit_ids worker_ctx["spikes"] = spikes @@ -328,7 +329,7 @@ def _init_worker_waveform_extractor( # used by ChunkRecordingExecutor -def _waveform_extractor_chunk(segment_index, start_frame, end_frame, worker_ctx): +def _worker_distribute_buffers(segment_index, start_frame, end_frame, worker_ctx): # recover variables of the worker recording = worker_ctx["recording"] unit_ids = worker_ctx["unit_ids"] @@ -349,16 +350,9 @@ def _waveform_extractor_chunk(segment_index, start_frame, end_frame, worker_ctx) # take only spikes in range [start_frame, end_frame] # this is a slice so no copy!! - i0 = np.searchsorted(in_seg_spikes["sample_index"], start_frame) - i1 = np.searchsorted(in_seg_spikes["sample_index"], end_frame) - if i0 != i1: - # protect from spikes on border : spike_time<0 or spike_time>seg_size - # useful only when max_spikes_per_unit is not None - # waveform will not be extracted and a zeros will be left in the memmap file - while (in_seg_spikes[i0]["sample_index"] - nbefore) < 0 and (i0 != i1): - i0 = i0 + 1 - while (in_seg_spikes[i1 - 1]["sample_index"] + nafter) > seg_size and (i0 != i1): - i1 = i1 - 1 + # the border of segment are protected by nbefore on left an nafter on the right + i0 = np.searchsorted(in_seg_spikes["sample_index"], max(start_frame, nbefore)) + i1 = np.searchsorted(in_seg_spikes["sample_index"], min(end_frame, seg_size - nafter)) # slice in absolut in spikes vector l0 = i0 + s0 @@ -382,10 +376,10 @@ def _waveform_extractor_chunk(segment_index, start_frame, end_frame, worker_ctx) if worker_ctx["mode"] == "memmap": # open file in demand (and also autoclose it after) - filename = worker_ctx["wfs_arrays_info"][unit_id] + filename = worker_ctx["arrays_info"][unit_id] wfs = np.load(str(filename), mmap_mode="r+") elif worker_ctx["mode"] == "shared_memory": - wfs = worker_ctx["wfs_arrays"][unit_id] + wfs = worker_ctx["waveforms_by_units"][unit_id] for pos in in_chunk_pos: sample_index = spikes[inds[pos]]["sample_index"] @@ -397,6 +391,282 @@ def _waveform_extractor_chunk(segment_index, start_frame, end_frame, worker_ctx) wfs[pos, :, :] = wf[:, sparsity_mask[unit_ind]] +def extract_waveforms_to_single_buffer( + recording, + spikes, + unit_ids, + nbefore, + nafter, + mode="memmap", + return_scaled=False, + file_path=None, + dtype=None, + sparsity_mask=None, + copy=False, + job_name=None, + **job_kwargs, +): + """ + Allocate a single buffer (memmap or or shared memory) and then distribute every waveform into it. + + Contrary to extract_waveforms_to_buffers() all waveforms are extracted in the same buffer, so the spike vector is + needed to recover waveforms unit by unit. Importantly in case of sparsity, the channels are not aligned across + units. + + Note: spikes near borders (nbefore/nafter) are not extracted and 0 are put the output buffer. + This ensures that spikes.shape[0] == all_waveforms.shape[0]. + + Important note: for the "shared_memory" mode wf_array_info contains reference to + the shared memmory buffer, this variable must be referenced as long as arrays is used. + This variable must also unlink() when the array is de-referenced. + To avoid this complicated behavior, by default (copy=True) the shared memmory buffer is copied into a standard + numpy array. + + + Parameters + ---------- + recording: recording + The recording object + spikes: 1d numpy array with several fields + Spikes handled as a unique vector. + This vector can be obtained with: `spikes = Sorting.to_spike_vector()` + unit_ids: list ot numpy + List of unit_ids + nbefore: int + N samples before spike + nafter: int + N samples after spike + mode: str + Mode to use ('memmap' | 'shared_memory') + return_scaled: bool + Scale traces before exporting to buffer or not. + file_path: str or path + In case of memmap mode, file to save npy file. + dtype: numpy.dtype + dtype for waveforms buffer + sparsity_mask: None or array of bool + If not None shape must be must be (len(unit_ids), len(channel_ids)) + copy: bool + If True (default), the output shared memory object is copied to a numpy standard array and no reference + to the internal shared memory object is kept. + If copy=False then the shared memory object is also returned. Please keep in mind that the shared memory object + need to be referenced as long as all_waveforms will be used otherwise it might produce segmentation + faults which are hard to debug. + Also when copy=False the SharedMemory will need to be unlink manually if proper cleanup of resources is desired. + + {} + + Returns + ------- + all_waveforms: numpy array + Single array with shape (nump_spikes, num_samples, num_channels) + + wf_array_info: dict of info + Optionally return in case of shared_memory if copy=False. + Dictionary to "construct" array in workers process (memmap file or sharemem info) + """ + nsamples = nbefore + nafter + + dtype = np.dtype(dtype) + if mode == "shared_memory": + assert file_path is None + else: + file_path = Path(file_path) + + num_spikes = spikes.size + if sparsity_mask is None: + num_chans = recording.get_num_channels() + else: + num_chans = max(np.sum(sparsity_mask, axis=1)) + shape = (num_spikes, nsamples, num_chans) + + if mode == "memmap": + all_waveforms = np.lib.format.open_memmap(file_path, mode="w+", dtype=dtype, shape=shape) + # wf_array_info = str(file_path) + wf_array_info = dict(filename=str(file_path)) + elif mode == "shared_memory": + if num_spikes == 0 or num_chans == 0: + all_waveforms = np.zeros(shape, dtype=dtype) + shm = None + shm_name = None + else: + all_waveforms, shm = make_shared_array(shape, dtype) + shm_name = shm.name + # wf_array_info = (shm, shm_name, dtype.str, shape) + wf_array_info = dict(shm=shm, shm_name=shm_name, dtype=dtype.str, shape=shape) + else: + raise ValueError("allocate_waveforms_buffers bad mode") + + job_kwargs = fix_job_kwargs(job_kwargs) + + if num_spikes > 0: + # and run + func = _worker_distribute_single_buffer + init_func = _init_worker_distribute_single_buffer + + init_args = ( + recording, + spikes, + wf_array_info, + nbefore, + nafter, + return_scaled, + mode, + sparsity_mask, + ) + if job_name is None: + job_name = f"extract waveforms {mode} mono buffer" + + processor = ChunkRecordingExecutor(recording, func, init_func, init_args, job_name=job_name, **job_kwargs) + processor.run() + + if mode == "memmap": + return all_waveforms + elif mode == "shared_memory": + if copy: + if shm is not None: + # release all sharedmem buffer + # empty array have None + shm.unlink() + return all_waveforms.copy() + else: + return all_waveforms, wf_array_info + + +def _init_worker_distribute_single_buffer( + recording, spikes, wf_array_info, nbefore, nafter, return_scaled, mode, sparsity_mask +): + worker_ctx = {} + worker_ctx["recording"] = recording + worker_ctx["wf_array_info"] = wf_array_info + worker_ctx["spikes"] = spikes + worker_ctx["nbefore"] = nbefore + worker_ctx["nafter"] = nafter + worker_ctx["return_scaled"] = return_scaled + worker_ctx["sparsity_mask"] = sparsity_mask + worker_ctx["mode"] = mode + + if mode == "memmap": + filename = wf_array_info["filename"] + all_waveforms = np.load(str(filename), mmap_mode="r+") + worker_ctx["all_waveforms"] = all_waveforms + elif mode == "shared_memory": + from multiprocessing.shared_memory import SharedMemory + + shm_name, dtype, shape = wf_array_info["shm_name"], wf_array_info["dtype"], wf_array_info["shape"] + shm = SharedMemory(shm_name) + all_waveforms = np.ndarray(shape=shape, dtype=dtype, buffer=shm.buf) + worker_ctx["shm"] = shm + worker_ctx["all_waveforms"] = all_waveforms + + # prepare segment slices + segment_slices = [] + for segment_index in range(recording.get_num_segments()): + s0 = np.searchsorted(spikes["segment_index"], segment_index) + s1 = np.searchsorted(spikes["segment_index"], segment_index + 1) + segment_slices.append((s0, s1)) + worker_ctx["segment_slices"] = segment_slices + + return worker_ctx + + +# used by ChunkRecordingExecutor +def _worker_distribute_single_buffer(segment_index, start_frame, end_frame, worker_ctx): + # recover variables of the worker + recording = worker_ctx["recording"] + segment_slices = worker_ctx["segment_slices"] + spikes = worker_ctx["spikes"] + nbefore = worker_ctx["nbefore"] + nafter = worker_ctx["nafter"] + return_scaled = worker_ctx["return_scaled"] + sparsity_mask = worker_ctx["sparsity_mask"] + all_waveforms = worker_ctx["all_waveforms"] + + seg_size = recording.get_num_samples(segment_index=segment_index) + + s0, s1 = segment_slices[segment_index] + in_seg_spikes = spikes[s0:s1] + + # take only spikes in range [start_frame, end_frame] + # this is a slice so no copy!! + # the border of segment are protected by nbefore on left an nafter on the right + i0 = np.searchsorted(in_seg_spikes["sample_index"], max(start_frame, nbefore)) + i1 = np.searchsorted(in_seg_spikes["sample_index"], min(end_frame, seg_size - nafter)) + + # slice in absolut in spikes vector + l0 = i0 + s0 + l1 = i1 + s0 + + if l1 > l0: + start = spikes[l0]["sample_index"] - nbefore + end = spikes[l1 - 1]["sample_index"] + nafter + + # load trace in memory + traces = recording.get_traces( + start_frame=start, end_frame=end, segment_index=segment_index, return_scaled=return_scaled + ) + + for spike_index in range(l0, l1): + sample_index = spikes[spike_index]["sample_index"] + unit_index = spikes[spike_index]["unit_index"] + wf = traces[sample_index - start - nbefore : sample_index - start + nafter, :] + + if sparsity_mask is None: + all_waveforms[spike_index, :, :] = wf + else: + mask = sparsity_mask[unit_index, :] + wf = wf[:, mask] + all_waveforms[spike_index, :, : wf.shape[1]] = wf + + if worker_ctx["mode"] == "memmap": + all_waveforms.flush() + + +def split_waveforms_by_units(unit_ids, spikes, all_waveforms, sparsity_mask=None, folder=None): + """ + Split a single buffer waveforms into waveforms by units (multi buffers or multi files). + + Parameters + ---------- + unit_ids: list or numpy array + List of unit ids + spikes: numpy array + The spike vector + all_waveforms : numpy array + Single buffer containing all waveforms + sparsity_mask : None or numpy array + Optionally the boolean sparsity mask + folder : None or str or Path + If a folder is given all waveforms by units are copied in a npy file using f"waveforms_{unit_id}.npy" naming. + + Returns + ------- + waveforms_by_units: dict of array + A dict of arrays. + In case of folder not None, this contain the memmap of the files. + """ + if folder is not None: + folder = Path(folder) + waveforms_by_units = {} + for unit_index, unit_id in enumerate(unit_ids): + mask = spikes["unit_index"] == unit_index + if sparsity_mask is not None: + chan_mask = sparsity_mask[unit_index, :] + num_chans = np.sum(chan_mask) + wfs = all_waveforms[mask, :, :][:, :, :num_chans] + else: + wfs = all_waveforms[mask, :, :] + + if folder is None: + waveforms_by_units[unit_id] = wfs + else: + np.save(folder / f"waveforms_{unit_id}.npy", wfs) + # this avoid keeping in memory all waveforms + waveforms_by_units[unit_id] = np.load(f"waveforms_{unit_id}.npy", mmap_mode="r") + + return waveforms_by_units + + def has_exceeding_spikes(recording, sorting): """ Check if the sorting objects has spikes exceeding the recording number of samples, for all segments diff --git a/src/spikeinterface/extractors/neoextractors/__init__.py b/src/spikeinterface/extractors/neoextractors/__init__.py index 0d9da1960a..3360b76147 100644 --- a/src/spikeinterface/extractors/neoextractors/__init__.py +++ b/src/spikeinterface/extractors/neoextractors/__init__.py @@ -25,6 +25,14 @@ read_openephys_event, ) from .plexon import PlexonRecordingExtractor, PlexonSortingExtractor, read_plexon, read_plexon_sorting +from .plexon2 import ( + Plexon2SortingExtractor, + Plexon2RecordingExtractor, + Plexon2EventExtractor, + read_plexon2, + read_plexon2_sorting, + read_plexon2_event, +) from .spike2 import Spike2RecordingExtractor, read_spike2 from .spikegadgets import SpikeGadgetsRecordingExtractor, read_spikegadgets from .spikeglx import SpikeGLXRecordingExtractor, read_spikeglx @@ -49,12 +57,18 @@ OpenEphysBinaryRecordingExtractor, OpenEphysLegacyRecordingExtractor, PlexonRecordingExtractor, + Plexon2RecordingExtractor, Spike2RecordingExtractor, SpikeGadgetsRecordingExtractor, SpikeGLXRecordingExtractor, TdtRecordingExtractor, ] -neo_sorting_extractors_list = [BlackrockSortingExtractor, MEArecSortingExtractor, NeuralynxSortingExtractor] +neo_sorting_extractors_list = [ + BlackrockSortingExtractor, + MEArecSortingExtractor, + NeuralynxSortingExtractor, + Plexon2SortingExtractor, +] -neo_event_extractors_list = [AlphaOmegaEventExtractor, OpenEphysBinaryEventExtractor] +neo_event_extractors_list = [AlphaOmegaEventExtractor, OpenEphysBinaryEventExtractor, Plexon2EventExtractor] diff --git a/src/spikeinterface/extractors/neoextractors/plexon2.py b/src/spikeinterface/extractors/neoextractors/plexon2.py new file mode 100644 index 0000000000..8dbfc67e90 --- /dev/null +++ b/src/spikeinterface/extractors/neoextractors/plexon2.py @@ -0,0 +1,103 @@ +from spikeinterface.core.core_tools import define_function_from_class + +from .neobaseextractor import NeoBaseRecordingExtractor, NeoBaseSortingExtractor, NeoBaseEventExtractor + + +class Plexon2RecordingExtractor(NeoBaseRecordingExtractor): + """ + Class for reading plexon pl2 files. + + Based on :py:class:`neo.rawio.Plexon2RawIO` + + Parameters + ---------- + file_path: str + The file path to load the recordings from. + stream_id: str, optional + If there are several streams, specify the stream id you want to load. + stream_name: str, optional + If there are several streams, specify the stream name you want to load. + all_annotations: bool, default: False + Load exhaustively all annotations from neo. + """ + + mode = "file" + NeoRawIOClass = "Plexon2RawIO" + name = "plexon2" + + def __init__(self, file_path, stream_id=None, stream_name=None, all_annotations=False): + neo_kwargs = self.map_to_neo_kwargs(file_path) + NeoBaseRecordingExtractor.__init__( + self, stream_id=stream_id, stream_name=stream_name, all_annotations=all_annotations, **neo_kwargs + ) + self._kwargs.update({"file_path": str(file_path)}) + + @classmethod + def map_to_neo_kwargs(cls, file_path): + neo_kwargs = {"filename": str(file_path)} + return neo_kwargs + + +class Plexon2SortingExtractor(NeoBaseSortingExtractor): + """ + Class for reading plexon spiking data from .pl2 files. + + Based on :py:class:`neo.rawio.Plexon2RawIO` + + Parameters + ---------- + file_path: str + The file path to load the recordings from. + sampling_frequency: float, default: None + The sampling frequency of the sorting (required for multiple streams with different sampling frequencies). + """ + + mode = "file" + NeoRawIOClass = "Plexon2RawIO" + neo_returns_frames = True + name = "plexon2" + + def __init__(self, file_path, sampling_frequency=None): + from neo.rawio import Plexon2RawIO + + neo_kwargs = self.map_to_neo_kwargs(file_path) + neo_reader = Plexon2RawIO(**neo_kwargs) + neo_reader.parse_header() + NeoBaseSortingExtractor.__init__(self, sampling_frequency=sampling_frequency, **neo_kwargs) + self._kwargs.update({"file_path": str(file_path), "sampling_frequency": sampling_frequency}) + + @classmethod + def map_to_neo_kwargs(cls, file_path): + neo_kwargs = {"filename": str(file_path)} + return neo_kwargs + + +class Plexon2EventExtractor(NeoBaseEventExtractor): + """ + Class for reading plexon spiking data from .pl2 files. + + Based on :py:class:`neo.rawio.Plexon2RawIO` + + Parameters + ---------- + folder_path: str + + """ + + mode = "file" + NeoRawIOClass = "Plexon2RawIO" + name = "plexon2" + + def __init__(self, folder_path, block_index=None): + neo_kwargs = self.map_to_neo_kwargs(folder_path) + NeoBaseEventExtractor.__init__(self, block_index=block_index, **neo_kwargs) + + @classmethod + def map_to_neo_kwargs(cls, folder_path): + neo_kwargs = {"filename": str(folder_path)} + return neo_kwargs + + +read_plexon2 = define_function_from_class(source_class=Plexon2RecordingExtractor, name="read_plexon2") +read_plexon2_sorting = define_function_from_class(source_class=Plexon2SortingExtractor, name="read_plexon2_sorting") +read_plexon2_event = define_function_from_class(source_class=Plexon2EventExtractor, name="read_plexon2_event") diff --git a/src/spikeinterface/extractors/tests/test_neoextractors.py b/src/spikeinterface/extractors/tests/test_neoextractors.py index 900bdec06e..ce2703d382 100644 --- a/src/spikeinterface/extractors/tests/test_neoextractors.py +++ b/src/spikeinterface/extractors/tests/test_neoextractors.py @@ -1,5 +1,6 @@ import unittest -from platform import python_version +import platform +import subprocess from packaging import version import pytest @@ -18,6 +19,38 @@ local_folder = get_global_dataset_folder() / "ephy_testing_data" +def has_plexon2_dependencies(): + """ + Check if required Plexon2 dependencies are installed on different OS. + """ + + os_type = platform.system() + + if os_type == "Windows": + # On Windows, no need for additional dependencies + return True + + elif os_type == "Linux": + # Check for 'wine' using dpkg + try: + result_wine = subprocess.run( + ["dpkg", "-l", "wine"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True + ) + except subprocess.CalledProcessError: + return False + + # Check for 'zugbruecke' using pip + try: + import zugbruecke + + return True + except ImportError: + return False + else: + # Not sure about MacOS + raise ValueError(f"Unsupported OS: {os_type}") + + class MearecRecordingTest(RecordingCommonTestSuite, unittest.TestCase): ExtractorClass = MEArecRecordingExtractor downloads = ["mearec"] @@ -218,7 +251,7 @@ class Spike2RecordingTest(RecordingCommonTestSuite, unittest.TestCase): @pytest.mark.skipif( - version.parse(python_version()) >= version.parse("3.10"), + version.parse(platform.python_version()) >= version.parse("3.10"), reason="Sonpy only testing with Python < 3.10!", ) class CedRecordingTest(RecordingCommonTestSuite, unittest.TestCase): @@ -290,6 +323,32 @@ def test_pickling(self): pass +# We run plexon2 tests only if we have dependencies (wine) +@pytest.mark.skipif(not has_plexon2_dependencies(), reason="Required dependencies not installed") +class Plexon2RecordingTest(RecordingCommonTestSuite, unittest.TestCase): + ExtractorClass = Plexon2RecordingExtractor + downloads = ["plexon"] + entities = [ + ("plexon/4chDemoPL2.pl2", {"stream_id": "3"}), + ] + + +@pytest.mark.skipif(not has_plexon2_dependencies(), reason="Required dependencies not installed") +class Plexon2EventTest(EventCommonTestSuite, unittest.TestCase): + ExtractorClass = Plexon2EventExtractor + downloads = ["plexon"] + entities = [ + ("plexon/4chDemoPL2.pl2"), + ] + + +@pytest.mark.skipif(not has_plexon2_dependencies(), reason="Required dependencies not installed") +class Plexon2SortingTest(SortingCommonTestSuite, unittest.TestCase): + ExtractorClass = Plexon2SortingExtractor + downloads = ["plexon"] + entities = [("plexon/4chDemoPL2.pl2", {"sampling_frequency": 40000})] + + if __name__ == "__main__": # test = MearecSortingTest() # test = SpikeGLXRecordingTest() @@ -304,7 +363,7 @@ def test_pickling(self): # test = PlexonRecordingTest() # test = PlexonSortingTest() # test = NeuralynxRecordingTest() - test = BlackrockRecordingTest() + test = Plexon2RecordingTest() # test = MCSRawRecordingTest() # test = KiloSortSortingTest() # test = Spike2RecordingTest() diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 3ebeafcfec..5a0148c5c4 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -7,6 +7,9 @@ from spikeinterface.core.waveform_extractor import WaveformExtractor, BaseWaveformExtractorExtension +# DEBUG = True + + class AmplitudeScalingsCalculator(BaseWaveformExtractorExtension): """ Computes amplitude scalings from WaveformExtractor. @@ -21,9 +24,25 @@ def __init__(self, waveform_extractor): self.spikes = self.waveform_extractor.sorting.to_spike_vector( extremum_channel_inds=extremum_channel_inds, use_cache=False ) - - def _set_params(self, sparsity, max_dense_channels, ms_before, ms_after): - params = dict(sparsity=sparsity, max_dense_channels=max_dense_channels, ms_before=ms_before, ms_after=ms_after) + self.collisions = None + + def _set_params( + self, + sparsity, + max_dense_channels, + ms_before, + ms_after, + handle_collisions, + delta_collision_ms, + ): + params = dict( + sparsity=sparsity, + max_dense_channels=max_dense_channels, + ms_before=ms_before, + ms_after=ms_after, + handle_collisions=handle_collisions, + delta_collision_ms=delta_collision_ms, + ) return params def _select_extension_data(self, unit_ids): @@ -43,6 +62,11 @@ def _run(self, **job_kwargs): ms_before = self._params["ms_before"] ms_after = self._params["ms_after"] + # collisions + handle_collisions = self._params["handle_collisions"] + delta_collision_ms = self._params["delta_collision_ms"] + delta_collision_samples = int(delta_collision_ms / 1000 * we.sampling_frequency) + return_scaled = we._params["return_scaled"] unit_ids = we.unit_ids @@ -67,6 +91,8 @@ def _run(self, **job_kwargs): assert recording.get_num_channels() <= self._params["max_dense_channels"], "" sparsity = ChannelSparsity.create_dense(we) sparsity_inds = sparsity.unit_id_to_channel_indices + + # easier to use in chunk function as spikes use unit_index instead o id unit_inds_to_channel_indices = {unit_ind: sparsity_inds[unit_id] for unit_ind, unit_id in enumerate(unit_ids)} all_templates = we.get_all_templates() @@ -93,6 +119,8 @@ def _run(self, **job_kwargs): cut_out_before, cut_out_after, return_scaled, + handle_collisions, + delta_collision_samples, ) processor = ChunkRecordingExecutor( recording, @@ -104,10 +132,18 @@ def _run(self, **job_kwargs): **job_kwargs, ) out = processor.run() - (amp_scalings,) = zip(*out) + (amp_scalings, collisions) = zip(*out) amp_scalings = np.concatenate(amp_scalings) - self._extension_data[f"amplitude_scalings"] = amp_scalings + collisions_dict = {} + if handle_collisions: + for collision in collisions: + collisions_dict.update(collision) + self.collisions = collisions_dict + # Note: collisions are note in _extension_data because they are not pickable. We only store the indices + self._extension_data["collisions"] = np.array(list(collisions_dict.keys())) + + self._extension_data["amplitude_scalings"] = amp_scalings def get_data(self, outputs="concatenated"): """ @@ -154,6 +190,8 @@ def compute_amplitude_scalings( max_dense_channels=16, ms_before=None, ms_after=None, + handle_collisions=True, + delta_collision_ms=2, load_if_exists=False, outputs="concatenated", **job_kwargs, @@ -165,22 +203,27 @@ def compute_amplitude_scalings( ---------- waveform_extractor: WaveformExtractor The waveform extractor object - sparsity: ChannelSparsity + sparsity: ChannelSparsity, default: None If waveforms are not sparse, sparsity is required if the number of channels is greater than `max_dense_channels`. If the waveform extractor is sparse, its sparsity is automatically used. - By default None max_dense_channels: int, default: 16 Maximum number of channels to allow running without sparsity. To compute amplitude scaling using dense waveforms, set this to None, sparsity to None, and pass dense waveforms as input. - ms_before : float, optional + ms_before : float, default: None The cut out to apply before the spike peak to extract local waveforms. - If None, the WaveformExtractor ms_before is used, by default None - ms_after : float, optional + If None, the WaveformExtractor ms_before is used. + ms_after : float, default: None The cut out to apply after the spike peak to extract local waveforms. - If None, the WaveformExtractor ms_after is used, by default None + If None, the WaveformExtractor ms_after is used. + handle_collisions: bool, default: True + Whether to handle collisions between spikes. If True, the amplitude scaling of colliding spikes + (defined as spikes within `delta_collision_ms` ms and with overlapping sparsity) is computed by fitting a + multi-linear regression model (with `sklearn.LinearRegression`). If False, each spike is fitted independently. + delta_collision_ms: float, default: 2 + The maximum time difference in ms before and after a spike to gather colliding spikes. load_if_exists : bool, default: False Whether to load precomputed spike amplitudes, if they already exist. - outputs: str + outputs: str, default: 'concatenated' How the output should be returned: - 'concatenated' - 'by_unit' @@ -197,7 +240,14 @@ def compute_amplitude_scalings( sac = waveform_extractor.load_extension(AmplitudeScalingsCalculator.extension_name) else: sac = AmplitudeScalingsCalculator(waveform_extractor) - sac.set_params(sparsity=sparsity, max_dense_channels=max_dense_channels, ms_before=ms_before, ms_after=ms_after) + sac.set_params( + sparsity=sparsity, + max_dense_channels=max_dense_channels, + ms_before=ms_before, + ms_after=ms_after, + handle_collisions=handle_collisions, + delta_collision_ms=delta_collision_ms, + ) sac.run(**job_kwargs) amps = sac.get_data(outputs=outputs) @@ -218,6 +268,8 @@ def _init_worker_amplitude_scalings( cut_out_before, cut_out_after, return_scaled, + handle_collisions, + delta_collision_samples, ): # create a local dict per worker worker_ctx = {} @@ -229,14 +281,24 @@ def _init_worker_amplitude_scalings( worker_ctx["nafter"] = nafter worker_ctx["cut_out_before"] = cut_out_before worker_ctx["cut_out_after"] = cut_out_after - worker_ctx["margin"] = max(nbefore, nafter) worker_ctx["return_scaled"] = return_scaled worker_ctx["unit_inds_to_channel_indices"] = unit_inds_to_channel_indices + worker_ctx["handle_collisions"] = handle_collisions + worker_ctx["delta_collision_samples"] = delta_collision_samples + + if not handle_collisions: + worker_ctx["margin"] = max(nbefore, nafter) + else: + # in this case we extend the margin to be able to get with collisions outside the chunk + margin_waveforms = max(nbefore, nafter) + max_margin_collisions = delta_collision_samples + margin_waveforms + worker_ctx["margin"] = max_margin_collisions return worker_ctx def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx): + # from sklearn.linear_model import LinearRegression from scipy.stats import linregress # recover variables of the worker @@ -250,16 +312,14 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) cut_out_after = worker_ctx["cut_out_after"] margin = worker_ctx["margin"] return_scaled = worker_ctx["return_scaled"] + handle_collisions = worker_ctx["handle_collisions"] + delta_collision_samples = worker_ctx["delta_collision_samples"] spikes_in_segment = spikes[segment_slices[segment_index]] i0 = np.searchsorted(spikes_in_segment["sample_index"], start_frame) i1 = np.searchsorted(spikes_in_segment["sample_index"], end_frame) - local_waveforms = [] - templates = [] - scalings = [] - if i0 != i1: local_spikes = spikes_in_segment[i0:i1] traces_with_margin, left, right = get_chunk_with_margin( @@ -272,8 +332,26 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) offsets = recording.get_property("offset_to_uV") traces_with_margin = traces_with_margin.astype("float32") * gains + offsets - # get all waveforms - for spike in local_spikes: + # set colliding spikes apart (if needed) + if handle_collisions: + # local spikes with margin! + i0_margin = np.searchsorted(spikes_in_segment["sample_index"], start_frame - left) + i1_margin = np.searchsorted(spikes_in_segment["sample_index"], end_frame + right) + local_spikes_w_margin = spikes_in_segment[i0_margin:i1_margin] + collisions_local = find_collisions( + local_spikes, local_spikes_w_margin, delta_collision_samples, unit_inds_to_channel_indices + ) + else: + collisions_local = {} + + # compute the scaling for each spike + scalings = np.zeros(len(local_spikes), dtype=float) + # collision_global transforms local spike index to global spike index + collisions_global = {} + for spike_index, spike in enumerate(local_spikes): + if spike_index in collisions_local.keys(): + # we deal with overlapping spikes later + continue unit_index = spike["unit_index"] sample_index = spike["sample_index"] sparse_indices = unit_inds_to_channel_indices[unit_index] @@ -291,10 +369,335 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx) else: local_waveform = traces_with_margin[cut_out_start:cut_out_end, sparse_indices] assert template.shape == local_waveform.shape - local_waveforms.append(local_waveform) - templates.append(template) + + # here we use linregress, which is equivalent to using sklearn LinearRegression with fit_intercept=True + # y = local_waveform.flatten() + # X = template.flatten()[:, np.newaxis] + # reg = LinearRegression(positive=True, fit_intercept=True).fit(X, y) + # scalings[spike_index] = reg.coef_[0] linregress_res = linregress(template.flatten(), local_waveform.flatten()) - scalings.append(linregress_res[0]) - scalings = np.array(scalings) + scalings[spike_index] = linregress_res[0] + + # deal with collisions + if len(collisions_local) > 0: + num_spikes_in_previous_segments = int( + np.sum([len(spikes[segment_slices[s]]) for s in range(segment_index)]) + ) + for spike_index, collision in collisions_local.items(): + scaled_amps = fit_collision( + collision, + traces_with_margin, + start_frame, + end_frame, + left, + right, + nbefore, + all_templates, + unit_inds_to_channel_indices, + cut_out_before, + cut_out_after, + ) + # the scaling for the current spike is at index 0 + scalings[spike_index] = scaled_amps[0] + + # make collision_dict indices "absolute" by adding i0 and the cumulative number of spikes in previous segments + collisions_global.update({spike_index + i0 + num_spikes_in_previous_segments: collision}) + else: + scalings = np.array([]) + collisions_global = {} + + return (scalings, collisions_global) + + +### Collision handling ### +def _are_unit_indices_overlapping(unit_inds_to_channel_indices, i, j): + """ + Returns True if the unit indices i and j are overlapping, False otherwise + + Parameters + ---------- + unit_inds_to_channel_indices: dict + A dictionary mapping unit indices to channel indices + i: int + The first unit index + j: int + The second unit index + + Returns + ------- + bool + True if the unit indices i and j are overlapping, False otherwise + """ + if len(np.intersect1d(unit_inds_to_channel_indices[i], unit_inds_to_channel_indices[j])) > 0: + return True + else: + return False + - return (scalings,) +def find_collisions(spikes, spikes_w_margin, delta_collision_samples, unit_inds_to_channel_indices): + """ + Finds the collisions between spikes. + + Parameters + ---------- + spikes: np.array + An array of spikes + spikes_w_margin: np.array + An array of spikes within the added margin + delta_collision_samples: int + The maximum number of samples between two spikes to consider them as overlapping + unit_inds_to_channel_indices: dict + A dictionary mapping unit indices to channel indices + + Returns + ------- + collision_spikes_dict: np.array + A dictionary with collisions. The key is the index of the spike with collision, the value is an + array of overlapping spikes, including the spike itself at position 0. + """ + # TODO: refactor to speed-up + collision_spikes_dict = {} + for spike_index, spike in enumerate(spikes): + # find the index of the spike within the spikes_w_margin + spike_index_w_margin = np.where(spikes_w_margin == spike)[0][0] + + # find the possible spikes per and post within delta_collision_samples + consecutive_window_pre = np.searchsorted( + spikes_w_margin["sample_index"], + spike["sample_index"] - delta_collision_samples, + ) + consecutive_window_post = np.searchsorted( + spikes_w_margin["sample_index"], + spike["sample_index"] + delta_collision_samples, + ) + # exclude the spike itself (it is included in the collision_spikes by construction) + pre_possible_consecutive_spike_indices = np.arange(consecutive_window_pre, spike_index_w_margin) + post_possible_consecutive_spike_indices = np.arange(spike_index_w_margin + 1, consecutive_window_post) + possible_overlapping_spike_indices = np.concatenate( + (pre_possible_consecutive_spike_indices, post_possible_consecutive_spike_indices) + ) + + # find the overlapping spikes in space as well + for possible_overlapping_spike_index in possible_overlapping_spike_indices: + if _are_unit_indices_overlapping( + unit_inds_to_channel_indices, + spike["unit_index"], + spikes_w_margin[possible_overlapping_spike_index]["unit_index"], + ): + if spike_index not in collision_spikes_dict: + collision_spikes_dict[spike_index] = np.array([spike]) + collision_spikes_dict[spike_index] = np.concatenate( + (collision_spikes_dict[spike_index], [spikes_w_margin[possible_overlapping_spike_index]]) + ) + return collision_spikes_dict + + +def fit_collision( + collision, + traces_with_margin, + start_frame, + end_frame, + left, + right, + nbefore, + all_templates, + unit_inds_to_channel_indices, + cut_out_before, + cut_out_after, +): + """ + Compute the best fit for a collision between a spike and its overlapping spikes. + The function first cuts out the traces around the spike and its overlapping spikes, then + fits a multi-linear regression model to the traces using the centered templates as predictors. + + Parameters + ---------- + collision: np.ndarray + A numpy array of shape (n_colliding_spikes, ) containing the colliding spikes (spike_dtype). + traces_with_margin: np.ndarray + A numpy array of shape (n_samples, n_channels) containing the traces with a margin. + start_frame: int + The start frame of the chunk for traces_with_margin. + end_frame: int + The end frame of the chunk for traces_with_margin. + left: int + The left margin of the chunk for traces_with_margin. + right: int + The right margin of the chunk for traces_with_margin. + nbefore: int + The number of samples before the spike to consider for the fit. + all_templates: np.ndarray + A numpy array of shape (n_units, n_samples, n_channels) containing the templates. + unit_inds_to_channel_indices: dict + A dictionary mapping unit indices to channel indices. + cut_out_before: int + The number of samples to cut out before the spike. + cut_out_after: int + The number of samples to cut out after the spike. + + Returns + ------- + np.ndarray + The fitted scaling factors for the colliding spikes. + """ + from sklearn.linear_model import LinearRegression + + # make center of the spike externally + sample_first_centered = np.min(collision["sample_index"]) - (start_frame - left) + sample_last_centered = np.max(collision["sample_index"]) - (start_frame - left) + + # construct sparsity as union between units' sparsity + sparse_indices = np.array([], dtype="int") + for spike in collision: + sparse_indices_i = unit_inds_to_channel_indices[spike["unit_index"]] + sparse_indices = np.union1d(sparse_indices, sparse_indices_i) + + local_waveform_start = max(0, sample_first_centered - cut_out_before) + local_waveform_end = min(traces_with_margin.shape[0], sample_last_centered + cut_out_after) + local_waveform = traces_with_margin[local_waveform_start:local_waveform_end, sparse_indices] + + y = local_waveform.T.flatten() + X = np.zeros((len(y), len(collision))) + for i, spike in enumerate(collision): + full_template = np.zeros_like(local_waveform) + # center wrt cutout traces + sample_centered = spike["sample_index"] - (start_frame - left) - local_waveform_start + template = all_templates[spike["unit_index"]][:, sparse_indices] + template_cut = template[nbefore - cut_out_before : nbefore + cut_out_after] + # deal with borders + if sample_centered - cut_out_before < 0: + full_template[: sample_centered + cut_out_after] = template_cut[cut_out_before - sample_centered :] + elif sample_centered + cut_out_after > end_frame + right: + full_template[sample_centered - cut_out_before :] = template_cut[: -cut_out_after - (end_frame + right)] + else: + full_template[sample_centered - cut_out_before : sample_centered + cut_out_after] = template_cut + X[:, i] = full_template.T.flatten() + + reg = LinearRegression(fit_intercept=True, positive=True).fit(X, y) + scalings = reg.coef_ + return scalings + + +# uncomment for debugging +# def plot_collisions(we, sparsity=None, num_collisions=None): +# """ +# Plot the fitting of collision spikes. + +# Parameters +# ---------- +# we : WaveformExtractor +# The WaveformExtractor object. +# sparsity : ChannelSparsity, default=None +# The ChannelSparsity. If None, only main channels are plotted. +# num_collisions : int, default=None +# Number of collisions to plot. If None, all collisions are plotted. +# """ +# assert we.is_extension("amplitude_scalings"), "Could not find amplitude scalings extension!" +# sac = we.load_extension("amplitude_scalings") +# handle_collisions = sac._params["handle_collisions"] +# assert handle_collisions, "Amplitude scalings was run without handling collisions!" +# scalings = sac.get_data() + +# # overlapping_mask = sac.overlapping_mask +# # num_collisions = num_collisions or len(overlapping_mask) +# spikes = sac.spikes +# collisions = sac._extension_data[f"collisions"] +# collision_keys = list(collisions.keys()) +# num_collisions = num_collisions or len(collisions) +# num_collisions = min(num_collisions, len(collisions)) + +# for i in range(num_collisions): +# overlapping_spikes = collisions[collision_keys[i]] +# ax = plot_one_collision( +# we, collision_keys[i], overlapping_spikes, spikes, scalings=scalings, sparsity=sparsity +# ) + + +# def plot_one_collision( +# we, +# spike_index, +# overlapping_spikes, +# spikes, +# scalings=None, +# sparsity=None, +# cut_out_samples=100, +# ax=None +# ): +# import matplotlib.pyplot as plt + +# if ax is None: +# fig, ax = plt.subplots() + +# recording = we.recording +# nbefore_nafter_max = max(we.nafter, we.nbefore) +# cut_out_samples = max(cut_out_samples, nbefore_nafter_max) + +# if sparsity is not None: +# unit_inds_to_channel_indices = sparsity.unit_id_to_channel_indices +# sparse_indices = np.array([], dtype="int") +# for spike in overlapping_spikes: +# sparse_indices_i = unit_inds_to_channel_indices[we.unit_ids[spike["unit_index"]]] +# sparse_indices = np.union1d(sparse_indices, sparse_indices_i) +# else: +# sparse_indices = np.unique(overlapping_spikes["channel_index"]) + +# channel_ids = recording.channel_ids[sparse_indices] + +# center_spike = overlapping_spikes[0] +# max_delta = np.max( +# [ +# np.abs(center_spike["sample_index"] - np.min(overlapping_spikes[1:]["sample_index"])), +# np.abs(center_spike["sample_index"] - np.max(overlapping_spikes[1:]["sample_index"])), +# ] +# ) +# sf = max(0, center_spike["sample_index"] - max_delta - cut_out_samples) +# ef = min( +# center_spike["sample_index"] + max_delta + cut_out_samples, +# recording.get_num_samples(segment_index=center_spike["segment_index"]), +# ) +# tr_overlap = recording.get_traces(start_frame=sf, end_frame=ef, channel_ids=channel_ids, return_scaled=True) +# ts = np.arange(sf, ef) / recording.sampling_frequency * 1000 +# max_tr = np.max(np.abs(tr_overlap)) + +# for ch, tr in enumerate(tr_overlap.T): +# _ = ax.plot(ts, tr + 1.2 * ch * max_tr, color="k") +# ax.text(ts[0], 1.2 * ch * max_tr - 0.3 * max_tr, f"Ch:{channel_ids[ch]}") + +# used_labels = [] +# for i, spike in enumerate(overlapping_spikes): +# label = f"U{spike['unit_index']}" +# if label in used_labels: +# label = None +# else: +# used_labels.append(label) +# ax.axvline( +# spike["sample_index"] / recording.sampling_frequency * 1000, color=f"C{spike['unit_index']}", label=label +# ) + +# if scalings is not None: +# fitted_traces = np.zeros_like(tr_overlap) + +# all_templates = we.get_all_templates() +# for i, spike in enumerate(overlapping_spikes): +# template = all_templates[spike["unit_index"]] +# overlap_index = np.where(spikes == spike)[0][0] +# template_scaled = scalings[overlap_index] * template +# template_scaled_sparse = template_scaled[:, sparse_indices] +# sample_start = spike["sample_index"] - we.nbefore +# sample_end = sample_start + template_scaled_sparse.shape[0] + +# fitted_traces[sample_start - sf : sample_end - sf] += template_scaled_sparse + +# for ch, temp in enumerate(template_scaled_sparse.T): +# ts_template = np.arange(sample_start, sample_end) / recording.sampling_frequency * 1000 +# _ = ax.plot(ts_template, temp + 1.2 * ch * max_tr, color=f"C{spike['unit_index']}", ls="--") + +# for ch, tr in enumerate(fitted_traces.T): +# _ = ax.plot(ts, tr + 1.2 * ch * max_tr, color="gray", alpha=0.7) + +# fitted_line = ax.get_lines()[-1] +# fitted_line.set_label("Fitted") + +# ax.legend() +# ax.set_title(f"Spike {spike_index} - sample {center_spike['sample_index']}") +# return ax diff --git a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py index 223122e927..c2ffcc6843 100644 --- a/src/spikeinterface/preprocessing/silence_periods.py +++ b/src/spikeinterface/preprocessing/silence_periods.py @@ -46,7 +46,7 @@ def __init__(self, recording, list_periods, mode="zeros", **random_chunk_kwargs) num_seg = recording.get_num_segments() if num_seg == 1: - if isinstance(list_periods, (list, np.ndarray)) and not np.isscalar(list_periods[0]): + if isinstance(list_periods, (list, np.ndarray)) and np.array(list_periods).ndim == 2: # when unique segment accept list instead of of list of list/arrays list_periods = [list_periods]