
Commit

change max_threads_per_process to max_threads_per_worker
samuelgarcia committed Nov 20, 2024
1 parent 4365321 commit e929820
Showing 11 changed files with 67 additions and 55 deletions.
2 changes: 1 addition & 1 deletion doc/get_started/quickstart.rst
@@ -287,7 +287,7 @@ available parameters are dictionaries and can be accessed with:
'detect_threshold': 5,
'freq_max': 5000.0,
'freq_min': 400.0,
'max_threads_per_process': 1,
'max_threads_per_worker': 1,
'mp_context': None,
'n_jobs': 20,
'nested_params': None,
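For context, a short sketch of how the renamed key surfaces when reading sorter defaults. The sorter name and the presence of job kwargs in its defaults dict are assumptions for illustration, not guaranteed by this commit:

import spikeinterface.sorters as ss

# Hypothetical example: a sorter that accepts job kwargs should expose
# the renamed key after this commit.
params = ss.get_default_sorter_params("spykingcircus2")
print(params.get("max_threads_per_worker"))  # expected: 1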
2 changes: 1 addition & 1 deletion src/spikeinterface/core/globals.py
@@ -97,7 +97,7 @@ def is_set_global_dataset_folder() -> bool:


########################################
_default_job_kwargs = dict(n_jobs=1, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1)
_default_job_kwargs = dict(pool_engine="thread", n_jobs=1, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_worker=1)

global global_job_kwargs
global_job_kwargs = _default_job_kwargs.copy()
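A minimal sketch of the new default in action through the globals API of this module; the printed values are assumed from _default_job_kwargs above, not captured output:

from spikeinterface.core import get_global_job_kwargs, set_global_job_kwargs

set_global_job_kwargs(max_threads_per_worker=2)  # renamed kwarg
print(get_global_job_kwargs())
# assumed output, mirroring _default_job_kwargs:
# {'pool_engine': 'thread', 'n_jobs': 1, 'chunk_duration': '1s',
#  'progress_bar': True, 'mp_context': None, 'max_threads_per_worker': 2}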
52 changes: 32 additions & 20 deletions src/spikeinterface/core/job_tools.py
@@ -48,7 +48,7 @@
"chunk_duration",
"progress_bar",
"mp_context",
"max_threads_per_process",
"max_threads_per_worker",
)

# these keys are the same and should not be in the final dict
@@ -65,6 +65,17 @@ def fix_job_kwargs(runtime_job_kwargs):

job_kwargs = get_global_job_kwargs()

# deprecation with backward compatibility
# this can be removed in 0.104.0
if "max_threads_per_process" in runtime_job_kwargs:
runtime_job_kwargs = runtime_job_kwargs.copy()
runtime_job_kwargs["max_threads_per_worker"] = runtime_job_kwargs.pop("max_threads_per_process")
warnings.warn(
"job_kwargs: max_threads_per_worker was changed to max_threads_per_worker",
DeprecationWarning,
stacklevel=2,
)

for k in runtime_job_kwargs:
assert k in job_keys, (
f"{k} is not a valid job keyword argument. " f"Available keyword arguments are: {list(job_keys)}"
@@ -311,7 +322,7 @@ class ChunkRecordingExecutor:
mp_context : "fork" | "spawn" | None, default: None
"fork" or "spawn". If None, the context is taken by the recording.get_preferred_mp_context().
"fork" is only safely available on LINUX systems.
max_threads_per_process : int or None, default: None
max_threads_per_worker : int or None, default: None
Limit the number of threads per worker using the threadpoolctl module.
This is used only when n_jobs > 1.
If None, no limits.
@@ -342,7 +353,7 @@ def __init__(
chunk_duration=None,
mp_context=None,
job_name="",
max_threads_per_process=1,
max_threads_per_worker=1,
need_worker_index=False,
):
self.recording = recording
@@ -375,7 +386,7 @@ def __init__(
n_jobs=self.n_jobs,
)
self.job_name = job_name
self.max_threads_per_process = max_threads_per_process
self.max_threads_per_worker = max_threads_per_worker

self.pool_engine = pool_engine

@@ -446,7 +457,7 @@ def run(self, recording_slices=None):
max_workers=n_jobs,
initializer=process_worker_initializer,
mp_context=multiprocessing.get_context(self.mp_context),
initargs=(self.func, self.init_func, self.init_args, self.max_threads_per_process, self.need_worker_index, lock, array_pid),
initargs=(self.func, self.init_func, self.init_args, self.max_threads_per_worker, self.need_worker_index, lock, array_pid),
) as executor:
results = executor.map(process_function_wrapper, recording_slices)

@@ -473,12 +484,13 @@ def run(self, recording_slices=None):

if self.need_worker_index:
lock = threading.Lock()
thread_started = 0
else:
lock = None

with ThreadPoolExecutor(
max_workers=n_jobs,
initializer=thread_worker_initializer,
initargs=(self.func, self.init_func, self.init_args, self.max_threads_per_process, thread_local_data, self.need_worker_index, lock),
initargs=(self.func, self.init_func, self.init_args, self.max_threads_per_worker, thread_local_data, self.need_worker_index, lock),
) as executor:


@@ -507,19 +519,19 @@ class WorkerFuncWrapper:
"""
small wrapper that handles:
* local worker_dict
* max_threads_per_process
* max_threads_per_worker
"""
def __init__(self, func, worker_dict, max_threads_per_process):
def __init__(self, func, worker_dict, max_threads_per_worker):
self.func = func
self.worker_dict = worker_dict
self.max_threads_per_process = max_threads_per_process
self.max_threads_per_worker = max_threads_per_worker

def __call__(self, args):
segment_index, start_frame, end_frame = args
if self.max_threads_per_process is None:
if self.max_threads_per_worker is None:
return self.func(segment_index, start_frame, end_frame, self.worker_dict)
else:
with threadpool_limits(limits=self.max_threads_per_process):
with threadpool_limits(limits=self.max_threads_per_worker):
return self.func(segment_index, start_frame, end_frame, self.worker_dict)

# see
@@ -531,12 +543,12 @@ def __call__(self, args):
global _process_func_wrapper


def process_worker_initializer(func, init_func, init_args, max_threads_per_process, need_worker_index, lock, array_pid):
def process_worker_initializer(func, init_func, init_args, max_threads_per_worker, need_worker_index, lock, array_pid):
global _process_func_wrapper
if max_threads_per_process is None:
if max_threads_per_worker is None:
worker_dict = init_func(*init_args)
else:
with threadpool_limits(limits=max_threads_per_process):
with threadpool_limits(limits=max_threads_per_worker):
worker_dict = init_func(*init_args)

if need_worker_index:
Expand All @@ -551,7 +563,7 @@ def process_worker_initializer(func, init_func, init_args, max_threads_per_proce
worker_dict["worker_index"] = worker_index
lock.release()

_process_func_wrapper = WorkerFuncWrapper(func, worker_dict, max_threads_per_process)
_process_func_wrapper = WorkerFuncWrapper(func, worker_dict, max_threads_per_worker)

def process_function_wrapper(args):
global _process_func_wrapper
@@ -561,11 +573,11 @@ def process_function_wrapper(args):
# used by each thread at init
global _thread_started

def thread_worker_initializer(func, init_func, init_args, max_threads_per_process, thread_local_data, need_worker_index, lock):
if max_threads_per_process is None:
def thread_worker_initializer(func, init_func, init_args, max_threads_per_worker, thread_local_data, need_worker_index, lock):
if max_threads_per_worker is None:
worker_dict = init_func(*init_args)
else:
with threadpool_limits(limits=max_threads_per_process):
with threadpool_limits(limits=max_threads_per_worker):
worker_dict = init_func(*init_args)

if need_worker_index:
@@ -576,7 +588,7 @@ def thread_worker_initializer(func, init_func, init_args, max_threads_per_proces
worker_dict["worker_index"] = worker_index
lock.release()

thread_local_data.func_wrapper = WorkerFuncWrapper(func, worker_dict, max_threads_per_process)
thread_local_data.func_wrapper = WorkerFuncWrapper(func, worker_dict, max_threads_per_worker)

def thread_function_wrapper(args):
thread_local_data = args[0]
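A minimal sketch of the backward-compatibility path added to fix_job_kwargs above: the old key is remapped to the new one and a DeprecationWarning is emitted (scheduled for removal in 0.104.0):

import warnings
from spikeinterface.core.job_tools import fix_job_kwargs

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    kwargs = fix_job_kwargs(dict(n_jobs=2, max_threads_per_process=4))

assert kwargs["max_threads_per_worker"] == 4  # old key remapped to the new name
assert any(issubclass(w.category, DeprecationWarning) for w in caught)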
6 changes: 3 additions & 3 deletions src/spikeinterface/core/tests/test_globals.py
@@ -36,15 +36,15 @@ def test_global_tmp_folder(create_cache_folder):


def test_global_job_kwargs():
job_kwargs = dict(n_jobs=4, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1)
job_kwargs = dict(n_jobs=4, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_worker=1)
global_job_kwargs = get_global_job_kwargs()

# test warning when not setting n_jobs and calling fix_job_kwargs
with pytest.warns(UserWarning):
job_kwargs_split = fix_job_kwargs({})

assert global_job_kwargs == dict(
n_jobs=1, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1
n_jobs=1, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_worker=1
)
set_global_job_kwargs(**job_kwargs)
assert get_global_job_kwargs() == job_kwargs
@@ -59,7 +59,7 @@ def test_global_job_kwargs():
set_global_job_kwargs(**partial_job_kwargs)
global_job_kwargs = get_global_job_kwargs()
assert global_job_kwargs == dict(
n_jobs=2, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1
n_jobs=2, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_worker=1
)
# test that fix_job_kwargs grabs global kwargs
new_job_kwargs = dict(n_jobs=cpu_count())
4 changes: 2 additions & 2 deletions src/spikeinterface/core/tests/test_job_tools.py
@@ -281,7 +281,7 @@ def test_worker_index():
# test_divide_segment_into_chunks()
# test_ensure_n_jobs()
# test_ensure_chunk_size()
# test_ChunkRecordingExecutor()
test_ChunkRecordingExecutor()
# test_fix_job_kwargs()
# test_split_job_kwargs()
test_worker_index()
# test_worker_index()
16 changes: 8 additions & 8 deletions src/spikeinterface/postprocessing/principal_component.py
@@ -316,13 +316,13 @@ def _run(self, verbose=False, **job_kwargs):
job_kwargs = fix_job_kwargs(job_kwargs)
n_jobs = job_kwargs["n_jobs"]
progress_bar = job_kwargs["progress_bar"]
max_threads_per_process = job_kwargs["max_threads_per_process"]
max_threads_per_worker = job_kwargs["max_threads_per_worker"]
mp_context = job_kwargs["mp_context"]

# fit model/models
# TODO: make parallel for by_channel_global and concatenated
if mode == "by_channel_local":
pca_models = self._fit_by_channel_local(n_jobs, progress_bar, max_threads_per_process, mp_context)
pca_models = self._fit_by_channel_local(n_jobs, progress_bar, max_threads_per_worker, mp_context)
for chan_ind, chan_id in enumerate(self.sorting_analyzer.channel_ids):
self.data[f"pca_model_{mode}_{chan_id}"] = pca_models[chan_ind]
pca_model = pca_models
@@ -415,7 +415,7 @@ def run_for_all_spikes(self, file_path=None, verbose=False, **job_kwargs):
)
processor.run()

def _fit_by_channel_local(self, n_jobs, progress_bar, max_threads_per_process, mp_context):
def _fit_by_channel_local(self, n_jobs, progress_bar, max_threads_per_worker, mp_context):
from sklearn.decomposition import IncrementalPCA

p = self.params
@@ -444,10 +444,10 @@ def _fit_by_channel_local(self, n_jobs, progress_bar, max_threads_per_process, m
pca = pca_models[chan_ind]
pca.partial_fit(wfs[:, :, wf_ind])
else:
# create list of args to parallelize. For convenience, the max_threads_per_process is passed
# create list of args to parallelize. For convenience, the max_threads_per_worker is passed
# as last argument
items = [
(chan_ind, pca_models[chan_ind], wfs[:, :, wf_ind], max_threads_per_process)
(chan_ind, pca_models[chan_ind], wfs[:, :, wf_ind], max_threads_per_worker)
for wf_ind, chan_ind in enumerate(channel_inds)
]
n_jobs = min(n_jobs, len(items))
@@ -687,12 +687,12 @@ def _init_work_all_pc_extractor(recording, sorting, all_pcs_args, nbefore, nafte


def _partial_fit_one_channel(args):
chan_ind, pca_model, wf_chan, max_threads_per_process = args
chan_ind, pca_model, wf_chan, max_threads_per_worker = args

if max_threads_per_process is None:
if max_threads_per_worker is None:
pca_model.partial_fit(wf_chan)
return chan_ind, pca_model
else:
with threadpool_limits(limits=int(max_threads_per_process)):
with threadpool_limits(limits=int(max_threads_per_worker)):
pca_model.partial_fit(wf_chan)
return chan_ind, pca_model
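The per-worker thread cap in _partial_fit_one_channel follows the general threadpoolctl pattern used throughout this commit; a self-contained sketch with an illustrative worker, not code from this repository:

import numpy as np
from threadpoolctl import threadpool_limits

def blas_heavy_worker(x, max_threads_per_worker=1):
    # Cap BLAS/OpenMP threads inside the worker so that
    # n_jobs workers x threads-per-worker does not oversubscribe the CPU.
    if max_threads_per_worker is None:
        return x @ x.T
    with threadpool_limits(limits=int(max_threads_per_worker)):
        return x @ x.T

print(blas_heavy_worker(np.random.rand(64, 64)).shape)  # (64, 64)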
src/spikeinterface/postprocessing/tests/test_principal_component.py
@@ -27,7 +27,7 @@ def test_multi_processing(self):
)
sorting_analyzer.compute("principal_components", mode="by_channel_local", n_jobs=2)
sorting_analyzer.compute(
"principal_components", mode="by_channel_local", n_jobs=2, max_threads_per_process=4, mp_context="spawn"
"principal_components", mode="by_channel_local", n_jobs=2, max_threads_per_worker=4, mp_context="spawn"
)

def test_mode_concatenated(self):
10 changes: 5 additions & 5 deletions src/spikeinterface/qualitymetrics/pca_metrics.py
@@ -58,7 +58,7 @@ def compute_pc_metrics(
n_jobs=1,
progress_bar=False,
mp_context=None,
max_threads_per_process=None,
max_threads_per_worker=None,
) -> dict:
"""
Calculate principal component derived metrics.
@@ -147,7 +147,7 @@ def compute_pc_metrics(
pcs = dense_projections[np.isin(all_labels, neighbor_unit_ids)][:, :, neighbor_channel_indices]
pcs_flat = pcs.reshape(pcs.shape[0], -1)

func_args = (pcs_flat, labels, non_nn_metrics, unit_id, unit_ids, qm_params, max_threads_per_process)
func_args = (pcs_flat, labels, non_nn_metrics, unit_id, unit_ids, qm_params, max_threads_per_worker)
items.append(func_args)

if not run_in_parallel and non_nn_metrics:
@@ -977,12 +977,12 @@ def _compute_isolation(pcs_target_unit, pcs_other_unit, n_neighbors: int):


def pca_metrics_one_unit(args):
(pcs_flat, labels, metric_names, unit_id, unit_ids, qm_params, max_threads_per_process) = args
(pcs_flat, labels, metric_names, unit_id, unit_ids, qm_params, max_threads_per_worker) = args

if max_threads_per_process is None:
if max_threads_per_worker is None:
return _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, qm_params)
else:
with threadpool_limits(limits=int(max_threads_per_process)):
with threadpool_limits(limits=int(max_threads_per_worker)):
return _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, qm_params)


6 changes: 3 additions & 3 deletions src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py
@@ -31,13 +31,13 @@ def test_pca_metrics_multi_processing(small_sorting_analyzer):

print(f"Computing PCA metrics with 1 thread per process")
res1 = compute_pc_metrics(
sorting_analyzer, n_jobs=-1, metric_names=metric_names, max_threads_per_process=1, progress_bar=True
sorting_analyzer, n_jobs=-1, metric_names=metric_names, max_threads_per_worker=1, progress_bar=True
)
print(f"Computing PCA metrics with 2 thread per process")
res2 = compute_pc_metrics(
sorting_analyzer, n_jobs=-1, metric_names=metric_names, max_threads_per_process=2, progress_bar=True
sorting_analyzer, n_jobs=-1, metric_names=metric_names, max_threads_per_worker=2, progress_bar=True
)
print("Computing PCA metrics with spawn context")
res2 = compute_pc_metrics(
sorting_analyzer, n_jobs=-1, metric_names=metric_names, max_threads_per_process=2, progress_bar=True
sorting_analyzer, n_jobs=-1, metric_names=metric_names, max_threads_per_worker=2, progress_bar=True
)
12 changes: 6 additions & 6 deletions src/spikeinterface/sortingcomponents/clustering/merge.py
Original file line number Diff line number Diff line change
@@ -261,7 +261,7 @@ def find_merge_pairs(
**job_kwargs,
# n_jobs=1,
# mp_context="fork",
# max_threads_per_process=1,
# max_threads_per_worker=1,
# progress_bar=True,
):
"""
@@ -299,7 +299,7 @@ def find_merge_pairs(

n_jobs = job_kwargs["n_jobs"]
mp_context = job_kwargs.get("mp_context", None)
max_threads_per_process = job_kwargs.get("max_threads_per_process", 1)
max_threads_per_worker = job_kwargs.get("max_threads_per_worker", 1)
progress_bar = job_kwargs["progress_bar"]

Executor = get_poolexecutor(n_jobs)
@@ -316,7 +316,7 @@ def find_pair_worker_init(
templates,
method,
method_kwargs,
max_threads_per_process,
max_threads_per_worker,
),
) as pool:
jobs = []
@@ -354,7 +354,7 @@ def find_pair_worker_init(
templates,
method,
method_kwargs,
max_threads_per_process,
max_threads_per_worker,
):
global _ctx
_ctx = {}
@@ -366,7 +366,7 @@ def find_pair_worker_init(
_ctx["method"] = method
_ctx["method_kwargs"] = method_kwargs
_ctx["method_class"] = find_pair_method_dict[method]
_ctx["max_threads_per_process"] = max_threads_per_process
_ctx["max_threads_per_worker"] = max_threads_per_worker

# if isinstance(features_dict_or_folder, dict):
# _ctx["features"] = features_dict_or_folder
@@ -380,7 +380,7 @@ def find_pair_worker_init(

def find_pair_function_wrapper(label0, label1):
global _ctx
with threadpool_limits(limits=_ctx["max_threads_per_process"]):
with threadpool_limits(limits=_ctx["max_threads_per_worker"]):
is_merge, label0, label1, shift, merge_value = _ctx["method_class"].merge(
label0,
label1,
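find_pair_worker_init above follows the standard initializer-plus-module-global pattern for process pools; a simplified, self-contained sketch (the real code also stores the recording, templates, features and the method class in _ctx):

from concurrent.futures import ProcessPoolExecutor

_ctx = None  # set once per worker process by the initializer

def worker_init(max_threads_per_worker):
    global _ctx
    _ctx = {"max_threads_per_worker": max_threads_per_worker}

def work(item):
    # heavy shared state lives in _ctx instead of being re-pickled per task
    return item, _ctx["max_threads_per_worker"]

if __name__ == "__main__":
    with ProcessPoolExecutor(max_workers=2, initializer=worker_init, initargs=(1,)) as pool:
        print(list(pool.map(work, range(4))))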
