Skip to content

Commit

Permalink
implement get_best_job_kwargs()
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelgarcia committed Nov 20, 2024
1 parent e929820 commit d4a6e95
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/spikeinterface/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@
write_python,
normal_pdf,
)
from .job_tools import ensure_n_jobs, ensure_chunk_size, ChunkRecordingExecutor, split_job_kwargs, fix_job_kwargs
from .job_tools import get_best_job_kwargs, ensure_n_jobs, ensure_chunk_size, ChunkRecordingExecutor, split_job_kwargs, fix_job_kwargs
from .recording_tools import (
write_binary_recording,
write_to_h5_dataset_format,
Expand Down
39 changes: 39 additions & 0 deletions src/spikeinterface/core/job_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,45 @@
"chunk_duration",
)

def get_best_job_kwargs():
    """
    Give the best possible job_kwargs for the current platform.

    Returns
    -------
    job_kwargs : dict
        Dict with keys "pool_engine", "mp_context", "n_jobs" and
        "max_threads_per_worker", suitable for passing to fix_job_kwargs().
    """

    n_cpu = os.cpu_count()

    if platform.system() == "Linux":
        # maybe we should test this more but with linux the fork is still faster than threading
        pool_engine = "process"
        mp_context = "fork"

        # this is totally empirical but this is a good start
        if n_cpu <= 16:
            # for small n_cpu lets make many processes
            n_jobs = n_cpu
            max_threads_per_worker = 1
        else:
            # lets have fewer processes with more threads each
            # bug fix: the original assigned `n_cpu` here instead of `n_jobs`,
            # leaving `n_jobs` unbound (NameError at the return) when n_cpu > 16
            n_jobs = n_cpu // 4
            max_threads_per_worker = 8

    else:  # windows and mac
        # on windows and macos the fork is forbidden and process+spawn is super slow at startup
        # so lets go to threads
        pool_engine = "thread"
        mp_context = None
        n_jobs = n_cpu
        max_threads_per_worker = 1

    return dict(
        pool_engine=pool_engine,
        mp_context=mp_context,
        n_jobs=n_jobs,
        max_threads_per_worker=max_threads_per_worker,
    )




def fix_job_kwargs(runtime_job_kwargs):
from .globals import get_global_job_kwargs, is_set_global_job_kwargs_set
Expand Down
9 changes: 7 additions & 2 deletions src/spikeinterface/core/tests/test_job_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import time

from spikeinterface.core import generate_recording, set_global_job_kwargs, get_global_job_kwargs
from spikeinterface.core import generate_recording, set_global_job_kwargs, get_global_job_kwargs, get_best_job_kwargs

from spikeinterface.core.job_tools import (
divide_segment_into_chunks,
Expand Down Expand Up @@ -277,11 +277,16 @@ def test_worker_index():
assert 0 in res
assert 1 in res

def test_get_best_job_kwargs():
    """Check that get_best_job_kwargs() returns a complete, sane job_kwargs dict."""
    job_kwargs = get_best_job_kwargs()
    print(job_kwargs)
    # the function must expose every key the job machinery expects;
    # a print-only test would silently pass on a broken return value
    for key in ("pool_engine", "mp_context", "n_jobs", "max_threads_per_worker"):
        assert key in job_kwargs
    assert job_kwargs["n_jobs"] >= 1
    assert job_kwargs["max_threads_per_worker"] >= 1

if __name__ == "__main__":
# test_divide_segment_into_chunks()
# test_ensure_n_jobs()
# test_ensure_chunk_size()
test_ChunkRecordingExecutor()
# test_ChunkRecordingExecutor()
# test_fix_job_kwargs()
# test_split_job_kwargs()
# test_worker_index()
test_get_best_job_kwargs()

0 comments on commit d4a6e95

Please sign in to comment.