diff --git a/src/spikeinterface/core/tests/test_waveform_tools.py b/src/spikeinterface/core/tests/test_waveform_tools.py index 52d7472c92..1d7e38832a 100644 --- a/src/spikeinterface/core/tests/test_waveform_tools.py +++ b/src/spikeinterface/core/tests/test_waveform_tools.py @@ -59,14 +59,15 @@ def test_waveform_tools(): ] some_modes = [ {"mode": "memmap"}, + {"mode": "shared_memory"}, ] - if platform.system() != "Windows": - # shared memory on windows is buggy... - some_modes.append( - { - "mode": "shared_memory", - } - ) + # if platform.system() != "Windows": + # # shared memory on windows is buggy... + # some_modes.append( + # { + # "mode": "shared_memory", + # } + # ) some_sparsity = [ dict(sparsity_mask=None), @@ -87,9 +88,11 @@ def test_waveform_tools(): if wf_folder.is_dir(): shutil.rmtree(wf_folder) wf_folder.mkdir(parents=True) - mode_kwargs_ = dict(**mode_kwargs, folder=wf_folder) - else: - mode_kwargs_ = mode_kwargs + wf_file_path = wf_folder / "waveforms_all_units.npy" + + mode_kwargs_ = dict(**mode_kwargs) + if mode_kwargs["mode"] == "memmap": + mode_kwargs_["folder" ] = wf_folder wfs_arrays = extract_waveforms_to_buffers( recording, @@ -113,6 +116,10 @@ def test_waveform_tools(): else: list_wfs_sparse.append(wfs_arrays) + mode_kwargs_ = dict(**mode_kwargs) + if mode_kwargs["mode"] == "memmap": + mode_kwargs_["file_path" ] = wf_file_path + all_waveforms = extract_waveforms_to_single_buffer( recording, spikes, diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index 53c0df68df..c363ac49dc 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -406,7 +406,7 @@ def extract_waveforms_to_single_buffer( nafter, mode="memmap", return_scaled=False, - folder=None, + file_path=None, dtype=None, sparsity_mask=None, copy=False, @@ -442,8 +442,8 @@ def extract_waveforms_to_single_buffer( Mode to use ('memmap' | 'shared_memory') return_scaled: bool Scale traces before exporting to buffer or not. - folder: str or path - In case of memmap mode, folder to save npy files + file_path: str or path + In case of memmap mode, file to save npy file. dtype: numpy.dtype dtype for waveforms buffer sparsity_mask: None or array of bool @@ -468,9 +468,9 @@ def extract_waveforms_to_single_buffer( dtype = np.dtype(dtype) if mode == "shared_memory": - assert folder is None + assert file_path is None else: - folder = Path(folder) + file_path = Path(file_path) num_spikes = spikes.size if sparsity_mask is None: @@ -480,9 +480,8 @@ def extract_waveforms_to_single_buffer( shape = (num_spikes, nsamples, num_chans) if mode == "memmap": - filename = str(folder / f"waveforms.npy") - all_waveforms = np.lib.format.open_memmap(filename, mode="w+", dtype=dtype, shape=shape) - wf_array_info = filename + all_waveforms = np.lib.format.open_memmap(file_path, mode="w+", dtype=dtype, shape=shape) + wf_array_info = str(file_path) elif mode == "shared_memory": if num_spikes == 0 or num_chans == 0: all_waveforms = np.zeros(shape, dtype=dtype) @@ -538,7 +537,6 @@ def _init_worker_distribute_single_buffer( worker_ctx = {} worker_ctx["recording"] = recording worker_ctx["wf_array_info"] = wf_array_info - worker_ctx["unit_ids"] = unit_ids worker_ctx["spikes"] = spikes worker_ctx["nbefore"] = nbefore worker_ctx["nafter"] = nafter @@ -574,7 +572,6 @@ def _init_worker_distribute_single_buffer( def _worker_distribute_single_buffer(segment_index, start_frame, end_frame, worker_ctx): # recover variables of the worker recording = worker_ctx["recording"] - unit_ids = worker_ctx["unit_ids"] segment_slices = worker_ctx["segment_slices"] spikes = worker_ctx["spikes"] nbefore = worker_ctx["nbefore"]