From d4a6e95d1c9f6a7d5cb9d5f4ca017a2240c187ad Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 20 Nov 2024 14:35:06 +0100 Subject: [PATCH] implement get_best_job_kwargs() --- src/spikeinterface/core/__init__.py | 2 +- src/spikeinterface/core/job_tools.py | 39 +++++++++++++++++++ .../core/tests/test_job_tools.py | 9 ++++- 3 files changed, 47 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/core/__init__.py b/src/spikeinterface/core/__init__.py index ead7007920..bea77decfc 100644 --- a/src/spikeinterface/core/__init__.py +++ b/src/spikeinterface/core/__init__.py @@ -90,7 +90,7 @@ write_python, normal_pdf, ) -from .job_tools import ensure_n_jobs, ensure_chunk_size, ChunkRecordingExecutor, split_job_kwargs, fix_job_kwargs +from .job_tools import get_best_job_kwargs, ensure_n_jobs, ensure_chunk_size, ChunkRecordingExecutor, split_job_kwargs, fix_job_kwargs from .recording_tools import ( write_binary_recording, write_to_h5_dataset_format, diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py index b37c9b7d69..b12ad7fc4d 100644 --- a/src/spikeinterface/core/job_tools.py +++ b/src/spikeinterface/core/job_tools.py @@ -59,6 +59,45 @@ "chunk_duration", ) +def get_best_job_kwargs(): + """ + Give the best possible job_kwargs for the platform. 
+ """ + + n_cpu = os.cpu_count() + + if platform.system() == "Linux": + # maybe we should test this more but with linux the fork is still faster than threading + pool_engine = "process" + mp_context = "fork" + + # this is totally empiricat but this is a good start + if n_cpu <= 16: + # for small n_cpu lets make many process + n_jobs = n_cpu + max_threads_per_worker = 1 + else: + # lets have less process with more thread each + n_cpu = int(n_cpu / 4) + max_threads_per_worker = 8 + + else: # windows and mac + # on windows and macos the fork is forbidden and process+spwan is super slow at startup + # so lets go to threads + pool_engine = "thread" + mp_context = None + n_jobs = n_cpu + max_threads_per_worker = 1 + + return dict( + pool_engine=pool_engine, + mp_context=mp_context, + n_jobs=n_jobs, + max_threads_per_worker=max_threads_per_worker, + ) + + + def fix_job_kwargs(runtime_job_kwargs): from .globals import get_global_job_kwargs, is_set_global_job_kwargs_set diff --git a/src/spikeinterface/core/tests/test_job_tools.py b/src/spikeinterface/core/tests/test_job_tools.py index 8872a259bf..3918fe8ec0 100644 --- a/src/spikeinterface/core/tests/test_job_tools.py +++ b/src/spikeinterface/core/tests/test_job_tools.py @@ -3,7 +3,7 @@ import time -from spikeinterface.core import generate_recording, set_global_job_kwargs, get_global_job_kwargs +from spikeinterface.core import generate_recording, set_global_job_kwargs, get_global_job_kwargs, get_best_job_kwargs from spikeinterface.core.job_tools import ( divide_segment_into_chunks, @@ -277,11 +277,16 @@ def test_worker_index(): assert 0 in res assert 1 in res +def test_get_best_job_kwargs(): + job_kwargs = get_best_job_kwargs() + print(job_kwargs) + if __name__ == "__main__": # test_divide_segment_into_chunks() # test_ensure_n_jobs() # test_ensure_chunk_size() - test_ChunkRecordingExecutor() + # test_ChunkRecordingExecutor() # test_fix_job_kwargs() # test_split_job_kwargs() # test_worker_index() + test_get_best_job_kwargs()