Feedback from Zach and Alessio: better test for waveform_tools
samuelgarcia committed Nov 22, 2024
1 parent 1736b65 commit f0ec139
Showing 3 changed files with 48 additions and 36 deletions.
25 changes: 11 additions & 14 deletions src/spikeinterface/core/job_tools.py
@@ -110,7 +110,7 @@ def fix_job_kwargs(runtime_job_kwargs):
runtime_job_kwargs = runtime_job_kwargs.copy()
runtime_job_kwargs["max_threads_per_worker"] = runtime_job_kwargs.pop("max_threads_per_process")
warnings.warn(
"job_kwargs: max_threads_per_worker was changed to max_threads_per_worker",
"job_kwargs: max_threads_per_process was changed to max_threads_per_worker, max_threads_per_process will be removed in 0.104",
DeprecationWarning,
stacklevel=2,
)
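For context, a minimal sketch of the corrected deprecation path from the caller's side. This assumes fix_job_kwargs is importable from spikeinterface.core.job_tools (the file shown above) and returns the fixed dict, as the surrounding code suggests:

import warnings
from spikeinterface.core.job_tools import fix_job_kwargs

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    fixed = fix_job_kwargs(dict(n_jobs=2, max_threads_per_process=4))

# the deprecated key is renamed and a DeprecationWarning is emitted
assert "max_threads_per_process" not in fixed
assert fixed["max_threads_per_worker"] == 4
assert any(issubclass(w.category, DeprecationWarning) for w in caught)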
@@ -346,7 +346,7 @@ class ChunkRecordingExecutor:
gather_func : None or callable, default: None
Optional function that is called in the main thread and retrieves the results of each worker.
This function can be used instead of `handle_returns` to implement custom storage on-the-fly.
pool_engine : "process" | "thread"
pool_engine : "process" | "thread", default: "thread"
If n_jobs > 1, whether to use a ProcessPoolExecutor ("process") or a ThreadPoolExecutor ("thread")
n_jobs : int, default: 1
Number of jobs to be used. Use -1 to use as many jobs as there are cores
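A hedged usage sketch of the two parameters documented above (values illustrative; chunk_duration is the same option used in the tests further down):

# "thread" is the new default engine; pool_engine="process" opts back into multiprocessing
job_kwargs = dict(pool_engine="thread", n_jobs=4, progress_bar=True, chunk_duration="1s")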
@@ -384,7 +384,7 @@ def __init__(
progress_bar=False,
handle_returns=False,
gather_func=None,
pool_engine="process",
pool_engine="thread",
n_jobs=1,
total_memory=None,
chunk_size=None,
@@ -400,12 +400,13 @@ def __init__(
self.init_func = init_func
self.init_args = init_args

if mp_context is None:
mp_context = recording.get_preferred_mp_context()
if mp_context is not None and platform.system() == "Windows":
assert mp_context != "fork", "'fork' mp_context not supported on Windows!"
elif mp_context == "fork" and platform.system() == "Darwin":
warnings.warn('As of Python 3.8 "fork" is no longer considered safe on macOS')
if pool_engine == "process":
if mp_context is None:
mp_context = recording.get_preferred_mp_context()
if mp_context is not None and platform.system() == "Windows":
assert mp_context != "fork", "'fork' mp_context not supported on Windows!"
elif mp_context == "fork" and platform.system() == "Darwin":
warnings.warn('As of Python 3.8 "fork" is no longer considered safe on macOS')

self.mp_context = mp_context
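The platform checks above now only run for the process engine. A standalone sketch of the same rules (an illustrative helper, not the library's API):

import platform
import warnings

def pick_mp_context(requested=None):
    """Illustrative helper mirroring the checks above; not SpikeInterface API."""
    if requested is None:
        return "spawn"  # safe cross-platform default
    if requested == "fork":
        if platform.system() == "Windows":
            raise ValueError("'fork' mp_context is not supported on Windows")
        if platform.system() == "Darwin":
            warnings.warn('As of Python 3.8 "fork" is no longer considered safe on macOS')
    return requested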

@@ -572,13 +573,9 @@ def __call__(self, args):
else:
with threadpool_limits(limits=self.max_threads_per_worker):
return self.func(segment_index, start_frame, end_frame, self.worker_dict)

# see
# https://stackoverflow.com/questions/10117073/how-to-use-initializer-to-set-up-my-multiprocess-pool
# the trick is: these variables are global per worker
# so they are not shared across processes
# global _worker_ctx
# global _func
# the trick is: these variables are global per worker process (so not shared between processes)
global _process_func_wrapper
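The Stack Overflow trick referenced above, as a self-contained sketch (names are illustrative, not the SpikeInterface internals): a module-level global assigned in the pool initializer is private to each worker process.

from concurrent.futures import ProcessPoolExecutor

_worker_state = None  # one copy per worker process, set by the initializer

def _init_worker(offset):
    global _worker_state
    _worker_state = {"offset": offset}

def _work(x):
    # each worker reads its own process-local copy of _worker_state
    return x + _worker_state["offset"]

if __name__ == "__main__":
    with ProcessPoolExecutor(max_workers=2, initializer=_init_worker, initargs=(10,)) as pool:
        print(list(pool.map(_work, range(4))))  # [10, 11, 12, 13]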


Expand Down
1 change: 0 additions & 1 deletion src/spikeinterface/core/tests/test_job_tools.py
@@ -81,7 +81,6 @@ def test_ensure_chunk_size():

def func(segment_index, start_frame, end_frame, worker_dict):
import os
import time

# print('func', segment_index, start_frame, end_frame, worker_dict, os.getpid())
time.sleep(0.010)
58 changes: 37 additions & 21 deletions src/spikeinterface/core/tests/test_waveform_tools.py
@@ -173,29 +173,45 @@ def test_estimate_templates_with_accumulator():

job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s")

templates = estimate_templates_with_accumulator(
recording, spikes, sorting.unit_ids, nbefore, nafter, return_scaled=True, **job_kwargs
)
# print(templates.shape)
assert templates.shape[0] == sorting.unit_ids.size
assert templates.shape[1] == nbefore + nafter
assert templates.shape[2] == recording.get_num_channels()
# here we compare the result of the same mechanism with several worker pool sizes
# this means that the accumulators are split and then agglomerated back
# this should lead to a very small diff
# n_jobs=1 is processed in a plain loop (no pool)
templates_by_worker = []

if platform.system() == "Linux":
engine_loop = ["thread", "process"]
else:
engine_loop = ["thread"]

for pool_engine in engine_loop:
for n_jobs in (1, 2, 8):
job_kwargs = dict(pool_engine=pool_engine, n_jobs=n_jobs, progress_bar=True, chunk_duration="1s")
templates = estimate_templates_with_accumulator(
recording, spikes, sorting.unit_ids, nbefore, nafter, return_scaled=True, **job_kwargs
)
assert templates.shape[0] == sorting.unit_ids.size
assert templates.shape[1] == nbefore + nafter
assert templates.shape[2] == recording.get_num_channels()
assert np.any(templates != 0)

templates_by_worker.append(templates)
if len(templates_by_worker) > 1:
templates_loop = templates_by_worker[0]
np.testing.assert_almost_equal(templates, templates_loop, decimal=4)

# import matplotlib.pyplot as plt
# fig, axs = plt.subplots(nrows=2, sharex=True)
# for unit_index, unit_id in enumerate(sorting.unit_ids):
# ax = axs[0]
# ax.set_title(f"{pool_engine} {n_jobs}")
# ax.plot(templates[unit_index, :, :].T.flatten())
# ax.plot(templates_loop[unit_index, :, :].T.flatten(), color="k", ls="--")
# ax = axs[1]
# ax.plot((templates - templates_loop)[unit_index, :, :].T.flatten(), color="k", ls="--")
# plt.show()
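Why the comparison above uses decimal=4 rather than exact equality: splitting an accumulation across workers changes the floating-point summation order. A quick standalone demonstration (float32 chosen for illustration):

import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal(1_000_000).astype(np.float32)
serial = x.sum()
split = np.float32(sum(chunk.sum() for chunk in np.array_split(x, 8)))
print(serial, split, abs(serial - split))  # typically a tiny but nonzero difference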

assert np.any(templates != 0)

job_kwargs = dict(n_jobs=1, progress_bar=True, chunk_duration="1s")
templates_loop = estimate_templates_with_accumulator(
recording, spikes, sorting.unit_ids, nbefore, nafter, return_scaled=True, **job_kwargs
)
np.testing.assert_almost_equal(templates, templates_loop, decimal=4)

# import matplotlib.pyplot as plt
# fig, ax = plt.subplots()
# for unit_index, unit_id in enumerate(sorting.unit_ids):
# ax.plot(templates[unit_index, :, :].T.flatten())
# ax.plot(templates_loop[unit_index, :, :].T.flatten(), color="k", ls="--")
# ax.plot((templates - templates_loop)[unit_index, :, :].T.flatten(), color="k", ls="--")
# plt.show()


def test_estimate_templates():
