diff --git a/doc/get_started/quickstart.rst b/doc/get_started/quickstart.rst index 3d45606a78..d1bf311340 100644 --- a/doc/get_started/quickstart.rst +++ b/doc/get_started/quickstart.rst @@ -287,7 +287,7 @@ available parameters are dictionaries and can be accessed with: 'detect_threshold': 5, 'freq_max': 5000.0, 'freq_min': 400.0, - 'max_threads_per_process': 1, + 'max_threads_per_worker': 1, 'mp_context': None, 'n_jobs': 20, 'nested_params': None, diff --git a/src/spikeinterface/core/globals.py b/src/spikeinterface/core/globals.py index 38f39c5481..195440c061 100644 --- a/src/spikeinterface/core/globals.py +++ b/src/spikeinterface/core/globals.py @@ -97,7 +97,7 @@ def is_set_global_dataset_folder() -> bool: ######################################## -_default_job_kwargs = dict(n_jobs=1, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1) +_default_job_kwargs = dict(pool_engine="thread", n_jobs=1, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_worker=1) global global_job_kwargs global_job_kwargs = _default_job_kwargs.copy() diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py index 2a4af1288c..b37c9b7d69 100644 --- a/src/spikeinterface/core/job_tools.py +++ b/src/spikeinterface/core/job_tools.py @@ -48,7 +48,7 @@ "chunk_duration", "progress_bar", "mp_context", - "max_threads_per_process", + "max_threads_per_worker", ) # theses key are the same and should not be in th final dict @@ -65,6 +65,17 @@ def fix_job_kwargs(runtime_job_kwargs): job_kwargs = get_global_job_kwargs() + # deprecation with backward compatibility + # this can be removed in 0.104.0 + if "max_threads_per_process" in runtime_job_kwargs: + runtime_job_kwargs = runtime_job_kwargs.copy() + runtime_job_kwargs["max_threads_per_worker"] = runtime_job_kwargs.pop("max_threads_per_process") + warnings.warn( + "job_kwargs: max_threads_per_worker was changed to max_threads_per_worker", + DeprecationWarning, + stacklevel=2, + ) + for k in runtime_job_kwargs: assert k in job_keys, ( f"{k} is not a valid job keyword argument. " f"Available keyword arguments are: {list(job_keys)}" @@ -311,7 +322,7 @@ class ChunkRecordingExecutor: mp_context : "fork" | "spawn" | None, default: None "fork" or "spawn". If None, the context is taken by the recording.get_preferred_mp_context(). "fork" is only safely available on LINUX systems. - max_threads_per_process : int or None, default: None + max_threads_per_worker : int or None, default: None Limit the number of thread per process using threadpoolctl modules. This used only when n_jobs>1 If None, no limits. @@ -342,7 +353,7 @@ def __init__( chunk_duration=None, mp_context=None, job_name="", - max_threads_per_process=1, + max_threads_per_worker=1, need_worker_index=False, ): self.recording = recording @@ -375,7 +386,7 @@ def __init__( n_jobs=self.n_jobs, ) self.job_name = job_name - self.max_threads_per_process = max_threads_per_process + self.max_threads_per_worker = max_threads_per_worker self.pool_engine = pool_engine @@ -446,7 +457,7 @@ def run(self, recording_slices=None): max_workers=n_jobs, initializer=process_worker_initializer, mp_context=multiprocessing.get_context(self.mp_context), - initargs=(self.func, self.init_func, self.init_args, self.max_threads_per_process, self.need_worker_index, lock, array_pid), + initargs=(self.func, self.init_func, self.init_args, self.max_threads_per_worker, self.need_worker_index, lock, array_pid), ) as executor: results = executor.map(process_function_wrapper, recording_slices) @@ -473,12 +484,13 @@ def run(self, recording_slices=None): if self.need_worker_index: lock = threading.Lock() - thread_started = 0 + else: + lock = None with ThreadPoolExecutor( max_workers=n_jobs, initializer=thread_worker_initializer, - initargs=(self.func, self.init_func, self.init_args, self.max_threads_per_process, thread_local_data, self.need_worker_index, lock), + initargs=(self.func, self.init_func, self.init_args, self.max_threads_per_worker, thread_local_data, self.need_worker_index, lock), ) as executor: @@ -507,19 +519,19 @@ class WorkerFuncWrapper: """ small wraper that handle: * local worker_dict - * max_threads_per_process + * max_threads_per_worker """ - def __init__(self, func, worker_dict, max_threads_per_process): + def __init__(self, func, worker_dict, max_threads_per_worker): self.func = func self.worker_dict = worker_dict - self.max_threads_per_process = max_threads_per_process + self.max_threads_per_worker = max_threads_per_worker def __call__(self, args): segment_index, start_frame, end_frame = args - if self.max_threads_per_process is None: + if self.max_threads_per_worker is None: return self.func(segment_index, start_frame, end_frame, self.worker_dict) else: - with threadpool_limits(limits=self.max_threads_per_process): + with threadpool_limits(limits=self.max_threads_per_worker): return self.func(segment_index, start_frame, end_frame, self.worker_dict) # see @@ -531,12 +543,12 @@ def __call__(self, args): global _process_func_wrapper -def process_worker_initializer(func, init_func, init_args, max_threads_per_process, need_worker_index, lock, array_pid): +def process_worker_initializer(func, init_func, init_args, max_threads_per_worker, need_worker_index, lock, array_pid): global _process_func_wrapper - if max_threads_per_process is None: + if max_threads_per_worker is None: worker_dict = init_func(*init_args) else: - with threadpool_limits(limits=max_threads_per_process): + with threadpool_limits(limits=max_threads_per_worker): worker_dict = init_func(*init_args) if need_worker_index: @@ -551,7 +563,7 @@ def process_worker_initializer(func, init_func, init_args, max_threads_per_proce worker_dict["worker_index"] = worker_index lock.release() - _process_func_wrapper = WorkerFuncWrapper(func, worker_dict, max_threads_per_process) + _process_func_wrapper = WorkerFuncWrapper(func, worker_dict, max_threads_per_worker) def process_function_wrapper(args): global _process_func_wrapper @@ -561,11 +573,11 @@ def process_function_wrapper(args): # use by thread at init global _thread_started -def thread_worker_initializer(func, init_func, init_args, max_threads_per_process, thread_local_data, need_worker_index, lock): - if max_threads_per_process is None: +def thread_worker_initializer(func, init_func, init_args, max_threads_per_worker, thread_local_data, need_worker_index, lock): + if max_threads_per_worker is None: worker_dict = init_func(*init_args) else: - with threadpool_limits(limits=max_threads_per_process): + with threadpool_limits(limits=max_threads_per_worker): worker_dict = init_func(*init_args) if need_worker_index: @@ -576,7 +588,7 @@ def thread_worker_initializer(func, init_func, init_args, max_threads_per_proces worker_dict["worker_index"] = worker_index lock.release() - thread_local_data.func_wrapper = WorkerFuncWrapper(func, worker_dict, max_threads_per_process) + thread_local_data.func_wrapper = WorkerFuncWrapper(func, worker_dict, max_threads_per_worker) def thread_function_wrapper(args): thread_local_data = args[0] diff --git a/src/spikeinterface/core/tests/test_globals.py b/src/spikeinterface/core/tests/test_globals.py index 9677378fc5..2b21cd8978 100644 --- a/src/spikeinterface/core/tests/test_globals.py +++ b/src/spikeinterface/core/tests/test_globals.py @@ -36,7 +36,7 @@ def test_global_tmp_folder(create_cache_folder): def test_global_job_kwargs(): - job_kwargs = dict(n_jobs=4, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1) + job_kwargs = dict(n_jobs=4, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_worker=1) global_job_kwargs = get_global_job_kwargs() # test warning when not setting n_jobs and calling fix_job_kwargs @@ -44,7 +44,7 @@ def test_global_job_kwargs(): job_kwargs_split = fix_job_kwargs({}) assert global_job_kwargs == dict( - n_jobs=1, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1 + n_jobs=1, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_worker=1 ) set_global_job_kwargs(**job_kwargs) assert get_global_job_kwargs() == job_kwargs @@ -59,7 +59,7 @@ def test_global_job_kwargs(): set_global_job_kwargs(**partial_job_kwargs) global_job_kwargs = get_global_job_kwargs() assert global_job_kwargs == dict( - n_jobs=2, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1 + n_jobs=2, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_worker=1 ) # test that fix_job_kwargs grabs global kwargs new_job_kwargs = dict(n_jobs=cpu_count()) diff --git a/src/spikeinterface/core/tests/test_job_tools.py b/src/spikeinterface/core/tests/test_job_tools.py index 5a32898411..8872a259bf 100644 --- a/src/spikeinterface/core/tests/test_job_tools.py +++ b/src/spikeinterface/core/tests/test_job_tools.py @@ -281,7 +281,7 @@ def test_worker_index(): # test_divide_segment_into_chunks() # test_ensure_n_jobs() # test_ensure_chunk_size() - # test_ChunkRecordingExecutor() + test_ChunkRecordingExecutor() # test_fix_job_kwargs() # test_split_job_kwargs() - test_worker_index() + # test_worker_index() diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py index 809f2c5bba..84fbfc5965 100644 --- a/src/spikeinterface/postprocessing/principal_component.py +++ b/src/spikeinterface/postprocessing/principal_component.py @@ -316,13 +316,13 @@ def _run(self, verbose=False, **job_kwargs): job_kwargs = fix_job_kwargs(job_kwargs) n_jobs = job_kwargs["n_jobs"] progress_bar = job_kwargs["progress_bar"] - max_threads_per_process = job_kwargs["max_threads_per_process"] + max_threads_per_worker = job_kwargs["max_threads_per_worker"] mp_context = job_kwargs["mp_context"] # fit model/models # TODO : make parralel for by_channel_global and concatenated if mode == "by_channel_local": - pca_models = self._fit_by_channel_local(n_jobs, progress_bar, max_threads_per_process, mp_context) + pca_models = self._fit_by_channel_local(n_jobs, progress_bar, max_threads_per_worker, mp_context) for chan_ind, chan_id in enumerate(self.sorting_analyzer.channel_ids): self.data[f"pca_model_{mode}_{chan_id}"] = pca_models[chan_ind] pca_model = pca_models @@ -415,7 +415,7 @@ def run_for_all_spikes(self, file_path=None, verbose=False, **job_kwargs): ) processor.run() - def _fit_by_channel_local(self, n_jobs, progress_bar, max_threads_per_process, mp_context): + def _fit_by_channel_local(self, n_jobs, progress_bar, max_threads_per_worker, mp_context): from sklearn.decomposition import IncrementalPCA p = self.params @@ -444,10 +444,10 @@ def _fit_by_channel_local(self, n_jobs, progress_bar, max_threads_per_process, m pca = pca_models[chan_ind] pca.partial_fit(wfs[:, :, wf_ind]) else: - # create list of args to parallelize. For convenience, the max_threads_per_process is passed + # create list of args to parallelize. For convenience, the max_threads_per_worker is passed # as last argument items = [ - (chan_ind, pca_models[chan_ind], wfs[:, :, wf_ind], max_threads_per_process) + (chan_ind, pca_models[chan_ind], wfs[:, :, wf_ind], max_threads_per_worker) for wf_ind, chan_ind in enumerate(channel_inds) ] n_jobs = min(n_jobs, len(items)) @@ -687,12 +687,12 @@ def _init_work_all_pc_extractor(recording, sorting, all_pcs_args, nbefore, nafte def _partial_fit_one_channel(args): - chan_ind, pca_model, wf_chan, max_threads_per_process = args + chan_ind, pca_model, wf_chan, max_threads_per_worker = args - if max_threads_per_process is None: + if max_threads_per_worker is None: pca_model.partial_fit(wf_chan) return chan_ind, pca_model else: - with threadpool_limits(limits=int(max_threads_per_process)): + with threadpool_limits(limits=int(max_threads_per_worker)): pca_model.partial_fit(wf_chan) return chan_ind, pca_model diff --git a/src/spikeinterface/postprocessing/tests/test_principal_component.py b/src/spikeinterface/postprocessing/tests/test_principal_component.py index 7a509c410f..ecfc39f2c6 100644 --- a/src/spikeinterface/postprocessing/tests/test_principal_component.py +++ b/src/spikeinterface/postprocessing/tests/test_principal_component.py @@ -27,7 +27,7 @@ def test_multi_processing(self): ) sorting_analyzer.compute("principal_components", mode="by_channel_local", n_jobs=2) sorting_analyzer.compute( - "principal_components", mode="by_channel_local", n_jobs=2, max_threads_per_process=4, mp_context="spawn" + "principal_components", mode="by_channel_local", n_jobs=2, max_threads_per_worker=4, mp_context="spawn" ) def test_mode_concatenated(self): diff --git a/src/spikeinterface/qualitymetrics/pca_metrics.py b/src/spikeinterface/qualitymetrics/pca_metrics.py index 4c68dfea59..55f91fd87f 100644 --- a/src/spikeinterface/qualitymetrics/pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/pca_metrics.py @@ -58,7 +58,7 @@ def compute_pc_metrics( n_jobs=1, progress_bar=False, mp_context=None, - max_threads_per_process=None, + max_threads_per_worker=None, ) -> dict: """ Calculate principal component derived metrics. @@ -147,7 +147,7 @@ def compute_pc_metrics( pcs = dense_projections[np.isin(all_labels, neighbor_unit_ids)][:, :, neighbor_channel_indices] pcs_flat = pcs.reshape(pcs.shape[0], -1) - func_args = (pcs_flat, labels, non_nn_metrics, unit_id, unit_ids, qm_params, max_threads_per_process) + func_args = (pcs_flat, labels, non_nn_metrics, unit_id, unit_ids, qm_params, max_threads_per_worker) items.append(func_args) if not run_in_parallel and non_nn_metrics: @@ -977,12 +977,12 @@ def _compute_isolation(pcs_target_unit, pcs_other_unit, n_neighbors: int): def pca_metrics_one_unit(args): - (pcs_flat, labels, metric_names, unit_id, unit_ids, qm_params, max_threads_per_process) = args + (pcs_flat, labels, metric_names, unit_id, unit_ids, qm_params, max_threads_per_worker) = args - if max_threads_per_process is None: + if max_threads_per_worker is None: return _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, qm_params) else: - with threadpool_limits(limits=int(max_threads_per_process)): + with threadpool_limits(limits=int(max_threads_per_worker)): return _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, qm_params) diff --git a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py index f2e912c6b4..ba8dae4619 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py @@ -31,13 +31,13 @@ def test_pca_metrics_multi_processing(small_sorting_analyzer): print(f"Computing PCA metrics with 1 thread per process") res1 = compute_pc_metrics( - sorting_analyzer, n_jobs=-1, metric_names=metric_names, max_threads_per_process=1, progress_bar=True + sorting_analyzer, n_jobs=-1, metric_names=metric_names, max_threads_per_worker=1, progress_bar=True ) print(f"Computing PCA metrics with 2 thread per process") res2 = compute_pc_metrics( - sorting_analyzer, n_jobs=-1, metric_names=metric_names, max_threads_per_process=2, progress_bar=True + sorting_analyzer, n_jobs=-1, metric_names=metric_names, max_threads_per_worker=2, progress_bar=True ) print("Computing PCA metrics with spawn context") res2 = compute_pc_metrics( - sorting_analyzer, n_jobs=-1, metric_names=metric_names, max_threads_per_process=2, progress_bar=True + sorting_analyzer, n_jobs=-1, metric_names=metric_names, max_threads_per_worker=2, progress_bar=True ) diff --git a/src/spikeinterface/sortingcomponents/clustering/merge.py b/src/spikeinterface/sortingcomponents/clustering/merge.py index 4a7b722aea..e618cfbfb6 100644 --- a/src/spikeinterface/sortingcomponents/clustering/merge.py +++ b/src/spikeinterface/sortingcomponents/clustering/merge.py @@ -261,7 +261,7 @@ def find_merge_pairs( **job_kwargs, # n_jobs=1, # mp_context="fork", - # max_threads_per_process=1, + # max_threads_per_worker=1, # progress_bar=True, ): """ @@ -299,7 +299,7 @@ def find_merge_pairs( n_jobs = job_kwargs["n_jobs"] mp_context = job_kwargs.get("mp_context", None) - max_threads_per_process = job_kwargs.get("max_threads_per_process", 1) + max_threads_per_worker = job_kwargs.get("max_threads_per_worker", 1) progress_bar = job_kwargs["progress_bar"] Executor = get_poolexecutor(n_jobs) @@ -316,7 +316,7 @@ def find_merge_pairs( templates, method, method_kwargs, - max_threads_per_process, + max_threads_per_worker, ), ) as pool: jobs = [] @@ -354,7 +354,7 @@ def find_pair_worker_init( templates, method, method_kwargs, - max_threads_per_process, + max_threads_per_worker, ): global _ctx _ctx = {} @@ -366,7 +366,7 @@ def find_pair_worker_init( _ctx["method"] = method _ctx["method_kwargs"] = method_kwargs _ctx["method_class"] = find_pair_method_dict[method] - _ctx["max_threads_per_process"] = max_threads_per_process + _ctx["max_threads_per_worker"] = max_threads_per_worker # if isinstance(features_dict_or_folder, dict): # _ctx["features"] = features_dict_or_folder @@ -380,7 +380,7 @@ def find_pair_worker_init( def find_pair_function_wrapper(label0, label1): global _ctx - with threadpool_limits(limits=_ctx["max_threads_per_process"]): + with threadpool_limits(limits=_ctx["max_threads_per_worker"]): is_merge, label0, label1, shift, merge_value = _ctx["method_class"].merge( label0, label1, diff --git a/src/spikeinterface/sortingcomponents/clustering/split.py b/src/spikeinterface/sortingcomponents/clustering/split.py index 15917934a8..3c2e878c39 100644 --- a/src/spikeinterface/sortingcomponents/clustering/split.py +++ b/src/spikeinterface/sortingcomponents/clustering/split.py @@ -65,7 +65,7 @@ def split_clusters( n_jobs = job_kwargs["n_jobs"] mp_context = job_kwargs.get("mp_context", None) progress_bar = job_kwargs["progress_bar"] - max_threads_per_process = job_kwargs.get("max_threads_per_process", 1) + max_threads_per_worker = job_kwargs.get("max_threads_per_worker", 1) original_labels = peak_labels peak_labels = peak_labels.copy() @@ -77,7 +77,7 @@ def split_clusters( max_workers=n_jobs, initializer=split_worker_init, mp_context=get_context(method=mp_context), - initargs=(recording, features_dict_or_folder, original_labels, method, method_kwargs, max_threads_per_process), + initargs=(recording, features_dict_or_folder, original_labels, method, method_kwargs, max_threads_per_worker), ) as pool: labels_set = np.setdiff1d(peak_labels, [-1]) current_max_label = np.max(labels_set) + 1 @@ -133,7 +133,7 @@ def split_clusters( def split_worker_init( - recording, features_dict_or_folder, original_labels, method, method_kwargs, max_threads_per_process + recording, features_dict_or_folder, original_labels, method, method_kwargs, max_threads_per_worker ): global _ctx _ctx = {} @@ -144,14 +144,14 @@ def split_worker_init( _ctx["method"] = method _ctx["method_kwargs"] = method_kwargs _ctx["method_class"] = split_methods_dict[method] - _ctx["max_threads_per_process"] = max_threads_per_process + _ctx["max_threads_per_worker"] = max_threads_per_worker _ctx["features"] = FeaturesLoader.from_dict_or_folder(features_dict_or_folder) _ctx["peaks"] = _ctx["features"]["peaks"] def split_function_wrapper(peak_indices, recursion_level): global _ctx - with threadpool_limits(limits=_ctx["max_threads_per_process"]): + with threadpool_limits(limits=_ctx["max_threads_per_worker"]): is_split, local_labels = _ctx["method_class"].split( peak_indices, _ctx["peaks"], _ctx["features"], recursion_level, **_ctx["method_kwargs"] )