Skip to content

Commit

Permalink
extract_waveforms_to_single_buffer change folder to file_path
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelgarcia committed Sep 7, 2023
1 parent 3453831 commit d79dbe2
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 20 deletions.
27 changes: 17 additions & 10 deletions src/spikeinterface/core/tests/test_waveform_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,15 @@ def test_waveform_tools():
]
some_modes = [
{"mode": "memmap"},
{"mode": "shared_memory"},
]
if platform.system() != "Windows":
# shared memory on windows is buggy...
some_modes.append(
{
"mode": "shared_memory",
}
)
# if platform.system() != "Windows":
# # shared memory on windows is buggy...
# some_modes.append(
# {
# "mode": "shared_memory",
# }
# )

some_sparsity = [
dict(sparsity_mask=None),
Expand All @@ -87,9 +88,11 @@ def test_waveform_tools():
if wf_folder.is_dir():
shutil.rmtree(wf_folder)
wf_folder.mkdir(parents=True)
mode_kwargs_ = dict(**mode_kwargs, folder=wf_folder)
else:
mode_kwargs_ = mode_kwargs
wf_file_path = wf_folder / "waveforms_all_units.npy"

mode_kwargs_ = dict(**mode_kwargs)
if mode_kwargs["mode"] == "memmap":
mode_kwargs_["folder" ] = wf_folder

wfs_arrays = extract_waveforms_to_buffers(
recording,
Expand All @@ -113,6 +116,10 @@ def test_waveform_tools():
else:
list_wfs_sparse.append(wfs_arrays)

mode_kwargs_ = dict(**mode_kwargs)
if mode_kwargs["mode"] == "memmap":
mode_kwargs_["file_path" ] = wf_file_path

all_waveforms = extract_waveforms_to_single_buffer(
recording,
spikes,
Expand Down
17 changes: 7 additions & 10 deletions src/spikeinterface/core/waveform_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ def extract_waveforms_to_single_buffer(
nafter,
mode="memmap",
return_scaled=False,
folder=None,
file_path=None,
dtype=None,
sparsity_mask=None,
copy=False,
Expand Down Expand Up @@ -442,8 +442,8 @@ def extract_waveforms_to_single_buffer(
Mode to use ('memmap' | 'shared_memory')
return_scaled: bool
Scale traces before exporting to buffer or not.
folder: str or path
In case of memmap mode, folder to save npy files
file_path: str or path
In case of memmap mode, file to save npy file.
dtype: numpy.dtype
dtype for waveforms buffer
sparsity_mask: None or array of bool
Expand All @@ -468,9 +468,9 @@ def extract_waveforms_to_single_buffer(

dtype = np.dtype(dtype)
if mode == "shared_memory":
assert folder is None
assert file_path is None
else:
folder = Path(folder)
file_path = Path(file_path)

num_spikes = spikes.size
if sparsity_mask is None:
Expand All @@ -480,9 +480,8 @@ def extract_waveforms_to_single_buffer(
shape = (num_spikes, nsamples, num_chans)

if mode == "memmap":
filename = str(folder / f"waveforms.npy")
all_waveforms = np.lib.format.open_memmap(filename, mode="w+", dtype=dtype, shape=shape)
wf_array_info = filename
all_waveforms = np.lib.format.open_memmap(file_path, mode="w+", dtype=dtype, shape=shape)
wf_array_info = str(file_path)
elif mode == "shared_memory":
if num_spikes == 0 or num_chans == 0:
all_waveforms = np.zeros(shape, dtype=dtype)
Expand Down Expand Up @@ -538,7 +537,6 @@ def _init_worker_distribute_single_buffer(
worker_ctx = {}
worker_ctx["recording"] = recording
worker_ctx["wf_array_info"] = wf_array_info
worker_ctx["unit_ids"] = unit_ids
worker_ctx["spikes"] = spikes
worker_ctx["nbefore"] = nbefore
worker_ctx["nafter"] = nafter
Expand Down Expand Up @@ -574,7 +572,6 @@ def _init_worker_distribute_single_buffer(
def _worker_distribute_single_buffer(segment_index, start_frame, end_frame, worker_ctx):
# recover variables of the worker
recording = worker_ctx["recording"]
unit_ids = worker_ctx["unit_ids"]
segment_slices = worker_ctx["segment_slices"]
spikes = worker_ctx["spikes"]
nbefore = worker_ctx["nbefore"]
Expand Down

0 comments on commit d79dbe2

Please sign in to comment.