proof of concept of chunkexecutor with thread
samuelgarcia committed Nov 8, 2024
1 parent 8a7895e commit e0ef39b
Showing 2 changed files with 89 additions and 40 deletions.
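
For orientation, a hedged usage sketch of the new `pool_engine` option from a caller's point of view, assembled from the updated test_job_tools.py further down. The `generate_recording` helper and the toy `func` / `init_func` here are illustrative stand-ins, not part of this commit:

from spikeinterface.core import generate_recording
from spikeinterface.core.job_tools import ChunkRecordingExecutor

recording = generate_recording(num_channels=4, durations=[2.0])

def init_func(arg):
    # builds the worker_dict handed to every chunk call
    return {"arg": arg}

def func(segment_index, start_frame, end_frame, worker_dict):
    # toy per-chunk work: just report the chunk length in samples
    return end_frame - start_frame

processor = ChunkRecordingExecutor(
    recording,
    func,
    init_func,
    (0,),                    # init_args
    pool_engine="thread",    # new in this commit; "process" keeps the previous behaviour
    n_jobs=2,
    chunk_duration="200ms",
    progress_bar=True,
    handle_returns=True,
)
chunk_lengths = processor.run()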
104 changes: 68 additions & 36 deletions src/spikeinterface/core/job_tools.py
@@ -12,7 +12,7 @@
import sys
from tqdm.auto import tqdm

from concurrent.futures import ProcessPoolExecutor
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
import multiprocessing as mp
from threadpoolctl import threadpool_limits

@@ -329,6 +329,7 @@ def __init__(
progress_bar=False,
handle_returns=False,
gather_func=None,
pool_engine="process",
n_jobs=1,
total_memory=None,
chunk_size=None,
@@ -370,6 +371,8 @@ def __init__(
self.job_name = job_name
self.max_threads_per_process = max_threads_per_process

self.pool_engine = pool_engine

if verbose:
chunk_memory = self.chunk_size * recording.get_num_channels() * np.dtype(recording.get_dtype()).itemsize
total_memory = chunk_memory * self.n_jobs
@@ -402,7 +405,7 @@ def run(self, recording_slices=None):

if self.n_jobs == 1:
if self.progress_bar:
recording_slices = tqdm(recording_slices, ascii=True, desc=self.job_name)
recording_slices = tqdm(recording_slices, desc=self.job_name, total=len(recording_slices))

worker_ctx = self.init_func(*self.init_args)
for segment_index, frame_start, frame_stop in recording_slices:
@@ -411,60 +414,89 @@ def run(self, recording_slices=None):
returns.append(res)
if self.gather_func is not None:
self.gather_func(res)

else:
n_jobs = min(self.n_jobs, len(recording_slices))

# parallel
with ProcessPoolExecutor(
max_workers=n_jobs,
initializer=worker_initializer,
mp_context=mp.get_context(self.mp_context),
initargs=(self.func, self.init_func, self.init_args, self.max_threads_per_process),
) as executor:
results = executor.map(function_wrapper, recording_slices)
if self.pool_engine == "process":

# parallel
with ProcessPoolExecutor(
max_workers=n_jobs,
initializer=process_worker_initializer,
mp_context=mp.get_context(self.mp_context),
initargs=(self.func, self.init_func, self.init_args, self.max_threads_per_process),
) as executor:
results = executor.map(process_function_wrapper, recording_slices)

elif self.pool_engine == "thread":
# only one shared context

worker_dict = self.init_func(*self.init_args)
thread_func = WorkerFuncWrapper(self.func, worker_dict, self.max_threads_per_process)

with ThreadPoolExecutor(
max_workers=n_jobs,
) as executor:
results = executor.map(thread_func, recording_slices)


if self.progress_bar:
results = tqdm(results, desc=self.job_name, total=len(recording_slices))
else:
raise ValueError("If n_jobs>1, pool_engine must be 'process' or 'thread'")


if self.progress_bar:
results = tqdm(results, desc=self.job_name, total=len(recording_slices))

for res in results:
if self.handle_returns:
returns.append(res)
if self.gather_func is not None:
self.gather_func(res)


for res in results:
if self.handle_returns:
returns.append(res)
if self.gather_func is not None:
self.gather_func(res)

return returns



class WorkerFuncWrapper:
def __init__(self, func, worker_dict, max_threads_per_process):
self.func = func
self.worker_dict = worker_dict
self.max_threads_per_process = max_threads_per_process

def __call__(self, args):
segment_index, start_frame, end_frame = args
if self.max_threads_per_process is None:
return self.func(segment_index, start_frame, end_frame, self.worker_dict)
else:
with threadpool_limits(limits=self.max_threads_per_process):
return self.func(segment_index, start_frame, end_frame, self.worker_dict)

# see
# https://stackoverflow.com/questions/10117073/how-to-use-initializer-to-set-up-my-multiprocess-pool
# the trick is: these variables are global per worker process,
# so they are not shared between processes
global _worker_ctx
global _func
# global _worker_ctx
# global _func
global _process_func_wrapper


def worker_initializer(func, init_func, init_args, max_threads_per_process):
global _worker_ctx
def process_worker_initializer(func, init_func, init_args, max_threads_per_process):
global _process_func_wrapper
if max_threads_per_process is None:
_worker_ctx = init_func(*init_args)
worker_dict = init_func(*init_args)
else:
with threadpool_limits(limits=max_threads_per_process):
_worker_ctx = init_func(*init_args)
_worker_ctx["max_threads_per_process"] = max_threads_per_process
global _func
_func = func
worker_dict = init_func(*init_args)
_process_func_wrapper = WorkerFuncWrapper(func, worker_dict, max_threads_per_process)


def function_wrapper(args):
segment_index, start_frame, end_frame = args
global _func
global _worker_ctx
max_threads_per_process = _worker_ctx["max_threads_per_process"]
if max_threads_per_process is None:
return _func(segment_index, start_frame, end_frame, _worker_ctx)
else:
with threadpool_limits(limits=max_threads_per_process):
return _func(segment_index, start_frame, end_frame, _worker_ctx)
def process_function_wrapper(args):
global _process_func_wrapper
return _process_func_wrapper(args)



# Here some utils copy/paste from DART (Charlie Windolf)
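
The renamed process_worker_initializer / process_function_wrapper pair above keeps relying on the pattern referenced in the StackOverflow link: the pool initializer runs once in each worker process and stores state in a module-level global, so every process gets its own private copy. A minimal standalone sketch of that pattern (stdlib only; the names _worker_state, _init_worker and _process_chunk are illustrative, not spikeinterface code):

from concurrent.futures import ProcessPoolExecutor

_worker_state = None  # one copy per worker process, set by the initializer


def _init_worker(offset):
    # runs once in each worker process, right after it starts
    global _worker_state
    _worker_state = {"offset": offset}


def _process_chunk(value):
    # runs in a worker process where _init_worker already populated _worker_state
    return value + _worker_state["offset"]


if __name__ == "__main__":
    with ProcessPoolExecutor(max_workers=2, initializer=_init_worker, initargs=(10,)) as executor:
        print(list(executor.map(_process_chunk, range(5))))  # [10, 11, 12, 13, 14]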
25 changes: 21 additions & 4 deletions src/spikeinterface/core/tests/test_job_tools.py
@@ -139,7 +139,7 @@ def __call__(self, res):

gathering_func2 = GatherClass()

# chunk + parallel + gather_func
# process + gather_func
processor = ChunkRecordingExecutor(
recording,
func,
@@ -148,6 +148,7 @@ def __call__(self, res):
verbose=True,
progress_bar=True,
gather_func=gathering_func2,
pool_engine="process",
n_jobs=2,
chunk_duration="200ms",
job_name="job_name",
@@ -157,21 +158,37 @@

assert gathering_func2.pos == num_chunks

# chunk + parallel + spawn
# process spawn
processor = ChunkRecordingExecutor(
recording,
func,
init_func,
init_args,
verbose=True,
progress_bar=True,
pool_engine="process",
mp_context="spawn",
n_jobs=2,
chunk_duration="200ms",
job_name="job_name",
)
processor.run()

# thread
processor = ChunkRecordingExecutor(
recording,
func,
init_func,
init_args,
verbose=True,
progress_bar=True,
pool_engine="thread",
n_jobs=2,
chunk_duration="200ms",
job_name="job_name",
)
processor.run()


def test_fix_job_kwargs():
# test negative n_jobs
@@ -224,6 +241,6 @@ def test_split_job_kwargs():
# test_divide_segment_into_chunks()
# test_ensure_n_jobs()
# test_ensure_chunk_size()
# test_ChunkRecordingExecutor()
test_fix_job_kwargs()
test_ChunkRecordingExecutor()
# test_fix_job_kwargs()
# test_split_job_kwargs()
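
The new "thread" case exercised by the last test block differs from the "process" engine in that init_func runs once in the main thread and the resulting worker_dict is shared by every worker thread, so anything stored in it must tolerate concurrent reads; max_threads_per_process is still honoured per call via threadpoolctl. A rough standalone sketch of that behaviour (the names SharedChunkFunc, add_offset and shared_dict are illustrative, not the spikeinterface API):

from concurrent.futures import ThreadPoolExecutor

from threadpoolctl import threadpool_limits


class SharedChunkFunc:
    """Mirror of the WorkerFuncWrapper idea: one worker_dict shared by all threads."""

    def __init__(self, func, worker_dict, max_threads_per_process=None):
        self.func = func
        self.worker_dict = worker_dict
        self.max_threads_per_process = max_threads_per_process

    def __call__(self, chunk):
        if self.max_threads_per_process is None:
            return self.func(chunk, self.worker_dict)
        # cap BLAS/OpenMP threads inside each call, as the wrapper does
        with threadpool_limits(limits=self.max_threads_per_process):
            return self.func(chunk, self.worker_dict)


def add_offset(chunk, worker_dict):
    return chunk + worker_dict["offset"]


shared_dict = {"offset": 100}  # built once, visible to every thread
with ThreadPoolExecutor(max_workers=2) as executor:
    results = list(executor.map(SharedChunkFunc(add_offset, shared_dict, 1), range(4)))
print(results)  # [100, 101, 102, 103]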
