From cecb211b4482af487dd0278fa2fd5e67f2efb0bf Mon Sep 17 00:00:00 2001
From: Samuel Garcia
Date: Tue, 12 Nov 2024 13:55:49 +0100
Subject: [PATCH] wip

---
 src/spikeinterface/core/job_tools.py | 33 +++++++++++++++++++++++-----
 1 file changed, 28 insertions(+), 5 deletions(-)

diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py
index db23a78b31..c514d4c74e 100644
--- a/src/spikeinterface/core/job_tools.py
+++ b/src/spikeinterface/core/job_tools.py
@@ -14,6 +14,7 @@

 from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
 import multiprocessing as mp
+import threading

 from threadpoolctl import threadpool_limits

@@ -445,13 +446,18 @@ def run(self, recording_slices=None):

         elif self.pool_engine == "thread":
             # only one shared context
-            worker_dict = self.init_func(*self.init_args)
-            thread_func = WorkerFuncWrapper(self.func, worker_dict, self.max_threads_per_process)
+            # worker_dict = self.init_func(*self.init_args)
+            # thread_func = WorkerFuncWrapper(self.func, worker_dict, self.max_threads_per_process)
+
+            thread_data = threading.local()

             with ThreadPoolExecutor(
                 max_workers=n_jobs,
+                initializer=thread_worker_initializer,
+                initargs=(self.func, self.init_func, self.init_args, self.max_threads_per_process, thread_data),
             ) as executor:
-                results = executor.map(thread_func, recording_slices)
+                recording_slices2 = [(thread_data, ) + args for args in recording_slices]
+                results = executor.map(thread_function_wrapper, recording_slices2)

                 if self.progress_bar:
                     results = tqdm(results, desc=self.job_name, total=len(recording_slices))
@@ -485,7 +491,7 @@ def __call__(self, args):

 # see
 # https://stackoverflow.com/questions/10117073/how-to-use-initializer-to-set-up-my-multiprocess-pool
-# the tricks is : theses 2 variables are global per worker
+# the trick is: this variable is global per worker
 # so they are not share in the same process
 # global _worker_ctx
 # global _func
@@ -501,11 +507,28 @@ def process_worker_initializer(func, init_func, init_args, max_threads_per_proce
             worker_dict = init_func(*init_args)
         _process_func_wrapper = WorkerFuncWrapper(func, worker_dict, max_threads_per_process)

-
 def process_function_wrapper(args):
     global _process_func_wrapper
     return _process_func_wrapper(args)

+
+def thread_worker_initializer(func, init_func, init_args, max_threads_per_process, thread_data):
+    if max_threads_per_process is None:
+        worker_dict = init_func(*init_args)
+    else:
+        with threadpool_limits(limits=max_threads_per_process):
+            worker_dict = init_func(*init_args)
+    thread_data._func_wrapper = WorkerFuncWrapper(func, worker_dict, max_threads_per_process)
+    # print("here", thread_data._func_wrapper)
+
+def thread_function_wrapper(args):
+    thread_data = args[0]
+    args = args[1:]
+    # thread_data = threading.local()
+    # print("there", thread_data._func_wrapper)
+    return thread_data._func_wrapper(args)
+
+

 # Here some utils copy/paste from DART (Charlie Windolf)
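
The patch hinges on one property of threading.local(): the object itself can be shared across threads, but attributes assigned to it are visible only in the thread that set them. Because ThreadPoolExecutor runs the initializer once in each worker thread, every worker ends up with its own WorkerFuncWrapper even though a single thread_data object is passed around. A minimal standalone sketch of the same pattern (the names local_ctx, init_worker, and add_offset below are illustrative, not part of spikeinterface):

    import threading
    from concurrent.futures import ThreadPoolExecutor

    local_ctx = threading.local()  # one shared object, but per-thread attributes

    def init_worker(offset):
        # runs once in every worker thread; each thread stores its own state
        local_ctx.offset = offset

    def add_offset(x):
        # executed in a worker thread: reads that thread's own state
        return x + local_ctx.offset

    with ThreadPoolExecutor(max_workers=4, initializer=init_worker, initargs=(10,)) as executor:
        print(list(executor.map(add_offset, range(5))))  # [10, 11, 12, 13, 14]

This is also why recording_slices2 prepends thread_data to every task tuple: thread_function_wrapper runs in a worker thread, so the lookup of thread_data._func_wrapper there resolves to that thread's wrapper rather than anything built in the main thread.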