Commit

wip
samuelgarcia committed Nov 12, 2024
1 parent 67b055b commit cecb211
Showing 1 changed file with 28 additions and 5 deletions.
33 changes: 28 additions & 5 deletions src/spikeinterface/core/job_tools.py
@@ -14,6 +14,7 @@

from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
import multiprocessing as mp
import threading
from threadpoolctl import threadpool_limits


@@ -445,13 +446,18 @@ def run(self, recording_slices=None):
            elif self.pool_engine == "thread":
                # only one shared context

                # worker_dict = self.init_func(*self.init_args)
                # thread_func = WorkerFuncWrapper(self.func, worker_dict, self.max_threads_per_process)

                thread_data = threading.local()

                with ThreadPoolExecutor(
                    max_workers=n_jobs,
                    initializer=thread_worker_initializer,
                    initargs=(self.func, self.init_func, self.init_args, self.max_threads_per_process, thread_data),
                ) as executor:
                    recording_slices2 = [(thread_data,) + args for args in recording_slices]
                    results = executor.map(thread_function_wrapper, recording_slices2)

                    if self.progress_bar:
                        results = tqdm(results, desc=self.job_name, total=len(recording_slices))
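The new thread path can be sketched in isolation (a minimal sketch; the names offset, work, and tasks are illustrative and not part of spikeinterface): ThreadPoolExecutor runs the initializer once per worker thread, each thread stores its own state on the shared threading.local() object, and the mapped function recovers that state from the first element of its argument tuple, just as recording_slices2 does above.

import threading
from concurrent.futures import ThreadPoolExecutor

thread_data = threading.local()

def initializer(offset):
    # runs once in each pool thread; attributes set here are private to that thread
    thread_data.offset = offset

def work(args):
    data, value = args
    # reads the copy that this thread's initializer stored
    return value + data.offset

with ThreadPoolExecutor(max_workers=4, initializer=initializer, initargs=(10,)) as executor:
    tasks = [(thread_data,) + (value,) for value in range(8)]
    print(list(executor.map(work, tasks)))  # [10, 11, ..., 17]

Passing the threading.local object through the arguments keeps each thread's wrapper reachable without a module-level global, which threads (unlike processes) would otherwise share and clobber.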
@@ -485,7 +491,7 @@ def __call__(self, args):

# see
# https://stackoverflow.com/questions/10117073/how-to-use-initializer-to-set-up-my-multiprocess-pool
# the trick is: this variable is global per worker
# so it is not shared between processes
# global _worker_ctx
# global _func
@@ -501,11 +507,28 @@ def process_worker_initializer(func, init_func, init_args, max_threads_per_process):
            worker_dict = init_func(*init_args)
    _process_func_wrapper = WorkerFuncWrapper(func, worker_dict, max_threads_per_process)


def process_function_wrapper(args):
    global _process_func_wrapper
    return _process_func_wrapper(args)

def thread_worker_initializer(func, init_func, init_args, max_threads_per_process, thread_data):
    if max_threads_per_process is None:
        worker_dict = init_func(*init_args)
    else:
        with threadpool_limits(limits=max_threads_per_process):
            worker_dict = init_func(*init_args)
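    # ThreadPoolExecutor calls this initializer once in every pool thread,
    # so the wrapper stored on thread_data is private to the calling thread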
    thread_data._func_wrapper = WorkerFuncWrapper(func, worker_dict, max_threads_per_process)
    # print("here", thread_data._func_wrapper)

def thread_function_wrapper(args):
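    # args[0] is the threading.local object passed through recording_slices2;
    # this thread's initializer attached the wrapper to it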
    thread_data = args[0]
    args = args[1:]
    # thread_data = threading.local()
    # print("there", thread_data._func_wrapper)
    return thread_data._func_wrapper(args)


# Here some utils copy/paste from DART (Charlie Windolf)

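For contrast, the process path that process_worker_initializer and process_function_wrapper implement above can be sketched as follows (a minimal sketch; _state, work, and the values are illustrative, not spikeinterface names): the initializer assigns a module-level global, and because each worker is a separate process, every worker holds a private copy of it, which is exactly the per-worker trick the comments in this diff describe.

from concurrent.futures import ProcessPoolExecutor

_state = None

def initializer(offset):
    global _state  # module-level, so each worker process gets its own copy
    _state = offset

def work(value):
    return value + _state

if __name__ == "__main__":  # required by the spawn start method
    with ProcessPoolExecutor(max_workers=2, initializer=initializer, initargs=(10,)) as executor:
        print(list(executor.map(work, range(4))))  # [10, 11, 12, 13]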