From f0ec139fdd52b43048becd9323d7475a6d41097e Mon Sep 17 00:00:00 2001
From: Samuel Garcia
Date: Fri, 22 Nov 2024 10:06:37 +0100
Subject: [PATCH] Feedback from Zach and Alessio; better tests for
 waveform_tools

---
 src/spikeinterface/core/job_tools.py          | 25 ++++----
 .../core/tests/test_job_tools.py              |  1 -
 .../core/tests/test_waveform_tools.py         | 58 ++++++++++++-------
 3 files changed, 48 insertions(+), 36 deletions(-)

diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py
index b12ad7fc4d..64a5c6cdbf 100644
--- a/src/spikeinterface/core/job_tools.py
+++ b/src/spikeinterface/core/job_tools.py
@@ -110,7 +110,7 @@ def fix_job_kwargs(runtime_job_kwargs):
         runtime_job_kwargs = runtime_job_kwargs.copy()
         runtime_job_kwargs["max_threads_per_worker"] = runtime_job_kwargs.pop("max_threads_per_process")
         warnings.warn(
-            "job_kwargs: max_threads_per_worker was changed to max_threads_per_worker",
+            "job_kwargs: max_threads_per_process was changed to max_threads_per_worker, max_threads_per_process will be removed in 0.104",
             DeprecationWarning,
             stacklevel=2,
         )
@@ -346,7 +346,7 @@ class ChunkRecordingExecutor:
     gather_func : None or callable, default: None
         Optional function that is called in the main thread and retrieves the results of each worker.
         This function can be used instead of `handle_returns` to implement custom storage on-the-fly.
-    pool_engine : "process" | "thread"
+    pool_engine : "process" | "thread", default: "thread"
        If n_jobs>1 then use ProcessPoolExecutor or ThreadPoolExecutor
     n_jobs : int, default: 1
         Number of jobs to be used. Use -1 to use as many jobs as number of cores
@@ -384,7 +384,7 @@ def __init__(
         progress_bar=False,
         handle_returns=False,
         gather_func=None,
-        pool_engine="process",
+        pool_engine="thread",
         n_jobs=1,
         total_memory=None,
         chunk_size=None,
@@ -400,12 +400,13 @@ def __init__(
         self.init_func = init_func
         self.init_args = init_args
 
-        if mp_context is None:
-            mp_context = recording.get_preferred_mp_context()
-        if mp_context is not None and platform.system() == "Windows":
-            assert mp_context != "fork", "'fork' mp_context not supported on Windows!"
-        elif mp_context == "fork" and platform.system() == "Darwin":
-            warnings.warn('As of Python 3.8 "fork" is no longer considered safe on macOS')
+        if pool_engine == "process":
+            if mp_context is None:
+                mp_context = recording.get_preferred_mp_context()
+            if mp_context is not None and platform.system() == "Windows":
+                assert mp_context != "fork", "'fork' mp_context not supported on Windows!"
+            elif mp_context == "fork" and platform.system() == "Darwin":
+                warnings.warn('As of Python 3.8 "fork" is no longer considered safe on macOS')
 
         self.mp_context = mp_context
 
@@ -572,13 +573,9 @@ def __call__(self, args):
         else:
             with threadpool_limits(limits=self.max_threads_per_worker):
                 return self.func(segment_index, start_frame, end_frame, self.worker_dict)
-
 
 # see
 # https://stackoverflow.com/questions/10117073/how-to-use-initializer-to-set-up-my-multiprocess-pool
-# the tricks is : thiw variables are global per worker
-# so they are not share in the same process
-# global _worker_ctx
-# global _func
+# the trick is: these variables are global per worker process (so they are not shared between processes)
 
 global _process_func_wrapper

diff --git a/src/spikeinterface/core/tests/test_job_tools.py b/src/spikeinterface/core/tests/test_job_tools.py
index 3918fe8ec0..824532a11e 100644
--- a/src/spikeinterface/core/tests/test_job_tools.py
+++ b/src/spikeinterface/core/tests/test_job_tools.py
@@ -81,7 +81,6 @@ def test_ensure_chunk_size():
 
 def func(segment_index, start_frame, end_frame, worker_dict):
     import os
-    import time
 
     #  print('func', segment_index, start_frame, end_frame, worker_dict, os.getpid())
     time.sleep(0.010)

diff --git a/src/spikeinterface/core/tests/test_waveform_tools.py b/src/spikeinterface/core/tests/test_waveform_tools.py
index d0e9358164..ed27815758 100644
--- a/src/spikeinterface/core/tests/test_waveform_tools.py
+++ b/src/spikeinterface/core/tests/test_waveform_tools.py
@@ -173,29 +173,45 @@ def test_estimate_templates_with_accumulator():
 
     job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s")
 
-    templates = estimate_templates_with_accumulator(
-        recording, spikes, sorting.unit_ids, nbefore, nafter, return_scaled=True, **job_kwargs
-    )
-    # print(templates.shape)
-    assert templates.shape[0] == sorting.unit_ids.size
-    assert templates.shape[1] == nbefore + nafter
-    assert templates.shape[2] == recording.get_num_channels()
+    # here we compare results from the same accumulator mechanism across several worker pool sizes
+    # the accumulators are split across workers and then aggregated back,
+    # so the runs should differ only by very small numerical errors
+    # n_jobs=1 runs in a plain loop (no pool) and serves as the reference
+    templates_by_worker = []
+
+    if platform.system() == "Linux":
+        engine_loop = ["thread", "process"]
+    else:
+        engine_loop = ["thread"]
+
+    for pool_engine in engine_loop:
+        for n_jobs in (1, 2, 8):
+            job_kwargs = dict(pool_engine=pool_engine, n_jobs=n_jobs, progress_bar=True, chunk_duration="1s")
+            templates = estimate_templates_with_accumulator(
+                recording, spikes, sorting.unit_ids, nbefore, nafter, return_scaled=True, **job_kwargs
+            )
+            assert templates.shape[0] == sorting.unit_ids.size
+            assert templates.shape[1] == nbefore + nafter
+            assert templates.shape[2] == recording.get_num_channels()
+            assert np.any(templates != 0)
+
+            templates_by_worker.append(templates)
+            if len(templates_by_worker) > 1:
+                templates_loop = templates_by_worker[0]
+                np.testing.assert_almost_equal(templates, templates_loop, decimal=4)
+
+            # import matplotlib.pyplot as plt
+            # fig, axs = plt.subplots(nrows=2, sharex=True)
+            # for unit_index, unit_id in enumerate(sorting.unit_ids):
+            #     ax = axs[0]
+            #     ax.set_title(f"{pool_engine} {n_jobs}")
+            #     ax.plot(templates[unit_index, :, :].T.flatten())
+            #     ax.plot(templates_loop[unit_index, :, :].T.flatten(), color="k", ls="--")
+            #     ax = axs[1]
+            #     ax.plot((templates - templates_loop)[unit_index, :, :].T.flatten(), color="k", ls="--")
+            # plt.show()
 
-    assert np.any(templates != 0)
-    job_kwargs = dict(n_jobs=1, progress_bar=True, chunk_duration="1s")
-    templates_loop = estimate_templates_with_accumulator(
-        recording, spikes, sorting.unit_ids, nbefore, nafter, return_scaled=True, **job_kwargs
-    )
-    np.testing.assert_almost_equal(templates, templates_loop, decimal=4)
-
-    # import matplotlib.pyplot as plt
-    # fig, ax = plt.subplots()
-    # for unit_index, unit_id in enumerate(sorting.unit_ids):
-    #     ax.plot(templates[unit_index, :, :].T.flatten())
-    #     ax.plot(templates_loop[unit_index, :, :].T.flatten(), color="k", ls="--")
-    #     ax.plot((templates - templates_loop)[unit_index, :, :].T.flatten(), color="k", ls="--")
-    # plt.show()
 
 
 def test_estimate_templates():
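
Reviewer note: to make the deprecation path in the first hunk concrete, here is a
minimal sketch of how the remapped kwarg behaves. It assumes only what the patch
shows (fix_job_kwargs lives in spikeinterface.core.job_tools and pops the old
key); the concrete kwarg values are illustrative.

    from spikeinterface.core.job_tools import fix_job_kwargs

    # The deprecated key is accepted, remapped to the new name, and a
    # DeprecationWarning is emitted; per the warning text in this patch,
    # "max_threads_per_process" will be removed in 0.104.
    job_kwargs = fix_job_kwargs(dict(n_jobs=2, max_threads_per_process=4))
    assert job_kwargs["max_threads_per_worker"] == 4
    assert "max_threads_per_process" not in job_kwargs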
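And a sketch of driving ChunkRecordingExecutor directly with the new pool_engine
kwarg, whose default this patch changes to "thread". The toy func/init_func
bodies and the generated recording are assumptions for illustration, not part of
the patch; only the pool_engine/n_jobs/chunk_duration kwargs come from it.

    from spikeinterface.core import generate_recording
    from spikeinterface.core.job_tools import ChunkRecordingExecutor

    recording = generate_recording(num_channels=4, durations=[2.0])

    def init_func():
        # Build the per-worker context (the worker_dict passed to func).
        return {}

    def func(segment_index, start_frame, end_frame, worker_dict):
        # Process one chunk; the value is collected because handle_returns=True.
        return end_frame - start_frame

    executor = ChunkRecordingExecutor(
        recording,
        func,
        init_func,
        init_args=(),
        pool_engine="thread",  # new default; "process" selects a ProcessPoolExecutor
        n_jobs=2,
        chunk_duration="1s",
        handle_returns=True,
        progress_bar=False,
    )
    chunk_frame_counts = executor.run()  # one entry per processed chunk

With pool_engine="thread" the mp_context checks are skipped entirely, which is
what the guarded "if pool_engine == 'process':" block in the second hunk encodes.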