
Move worker_index to job_tools.py
samuelgarcia committed Nov 20, 2024
1 parent cecb211 commit 4365321
Showing 4 changed files with 218 additions and 135 deletions.
108 changes: 81 additions & 27 deletions src/spikeinterface/core/job_tools.py
@@ -13,7 +13,7 @@
from tqdm.auto import tqdm

from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
import multiprocessing as mp
import multiprocessing
import threading
from threadpoolctl import threadpool_limits

@@ -289,6 +289,8 @@ class ChunkRecordingExecutor:
If True, output is verbose
job_name : str, default: ""
Job name
progress_bar : bool, default: False
If True, a progress bar is printed to monitor the progress of the process
handle_returns : bool, default: False
If True, the function can return values
gather_func : None or callable, default: None
@@ -313,9 +315,8 @@ class ChunkRecordingExecutor:
Limit the number of threads per process using the threadpoolctl module.
This is used only when n_jobs > 1.
If None, no limits.
progress_bar : bool, default: False
If True, a progress bar is printed to monitor the progress of the process
need_worker_index : bool, default: False
If True, each worker will also have a "worker_index" entry injected into its local worker dict.
Returns
-------
@@ -342,6 +343,7 @@ def __init__(
mp_context=None,
job_name="",
max_threads_per_process=1,
need_worker_index=False,
):
self.recording = recording
self.func = func
@@ -377,6 +379,8 @@ def __init__(

self.pool_engine = pool_engine

self.need_worker_index = need_worker_index

if verbose:
chunk_memory = self.chunk_size * recording.get_num_channels() * np.dtype(recording.get_dtype()).itemsize
total_memory = chunk_memory * self.n_jobs
@@ -412,9 +416,12 @@ def run(self, recording_slices=None):
if self.progress_bar:
recording_slices = tqdm(recording_slices, desc=self.job_name, total=len(recording_slices))

worker_ctx = self.init_func(*self.init_args)
worker_dict = self.init_func(*self.init_args)
if self.need_worker_index:
worker_dict["worker_index"] = 0

for segment_index, frame_start, frame_stop in recording_slices:
res = self.func(segment_index, frame_start, frame_stop, worker_ctx)
res = self.func(segment_index, frame_start, frame_stop, worker_dict)
if self.handle_returns:
returns.append(res)
if self.gather_func is not None:
@@ -425,12 +432,21 @@

if self.pool_engine == "process":

if self.need_worker_index:
lock = multiprocessing.Lock()
array_pid = multiprocessing.Array("i", n_jobs)
for i in range(n_jobs):
array_pid[i] = -1
else:
lock = None
array_pid = None

# parallel
with ProcessPoolExecutor(
max_workers=n_jobs,
initializer=process_worker_initializer,
mp_context=mp.get_context(self.mp_context),
initargs=(self.func, self.init_func, self.init_args, self.max_threads_per_process),
mp_context=multiprocessing.get_context(self.mp_context),
initargs=(self.func, self.init_func, self.init_args, self.max_threads_per_process, self.need_worker_index, lock, array_pid),
) as executor:
results = executor.map(process_function_wrapper, recording_slices)

@@ -444,29 +460,41 @@ def run(self, recording_slices=None):
self.gather_func(res)

elif self.pool_engine == "thread":
# only one shared context
# this is needed to create a per-worker local dict where the initializer will push the func wrapper
thread_local_data = threading.local()

# worker_dict = self.init_func(*self.init_args)
# thread_func = WorkerFuncWrapper(self.func, worker_dict, self.max_threads_per_process)
global _thread_started
_thread_started = 0

thread_data = threading.local()
if self.progress_bar:
# here the tqdm threading does not work (maybe a collision), so we need to create a pbar
# before spawning the threads
pbar = tqdm(desc=self.job_name, total=len(recording_slices))

if self.need_worker_index:
lock = threading.Lock()
else:
lock = None

with ThreadPoolExecutor(
max_workers=n_jobs,
initializer=thread_worker_initializer,
initargs=(self.func, self.init_func, self.init_args, self.max_threads_per_process, thread_data),
initargs=(self.func, self.init_func, self.init_args, self.max_threads_per_process, thread_local_data, self.need_worker_index, lock),
) as executor:
recording_slices2 = [(thread_data, ) + args for args in recording_slices]
results = executor.map(thread_function_wrapper, recording_slices2)

if self.progress_bar:
results = tqdm(results, desc=self.job_name, total=len(recording_slices))

recording_slices2 = [(thread_local_data, ) + args for args in recording_slices]
results = executor.map(thread_function_wrapper, recording_slices2)

for res in results:
if self.progress_bar:
pbar.update(1)
if self.handle_returns:
returns.append(res)
if self.gather_func is not None:
self.gather_func(res)
if self.progress_bar:
pbar.close()
del pbar

else:
raise ValueError("If n_jobs>1 pool_engine must be 'process' or 'thread'")
@@ -476,6 +504,11 @@ def run(self, recording_slices=None):


class WorkerFuncWrapper:
"""
Small wrapper that handles:
* local worker_dict
* max_threads_per_process
"""
def __init__(self, func, worker_dict, max_threads_per_process):
self.func = func
self.worker_dict = worker_dict
@@ -498,36 +531,57 @@ def __call__(self, args):
global _process_func_wrapper


def process_worker_initializer(func, init_func, init_args, max_threads_per_process):
def process_worker_initializer(func, init_func, init_args, max_threads_per_process, need_worker_index, lock, array_pid):
global _process_func_wrapper
if max_threads_per_process is None:
worker_dict = init_func(*init_args)
else:
with threadpool_limits(limits=max_threads_per_process):
worker_dict = init_func(*init_args)

if need_worker_index:
child_process = multiprocessing.current_process()
lock.acquire()
worker_index = None
for i in range(len(array_pid)):
if array_pid[i] == -1:
worker_index = i
array_pid[i] = child_process.ident
break
worker_dict["worker_index"] = worker_index
lock.release()

_process_func_wrapper = WorkerFuncWrapper(func, worker_dict, max_threads_per_process)

def process_function_wrapper(args):
global _process_func_wrapper
return _process_func_wrapper(args)

def thread_worker_initializer(func, init_func, init_args, max_threads_per_process, thread_data):

# used by threads at init
global _thread_started

def thread_worker_initializer(func, init_func, init_args, max_threads_per_process, thread_local_data, need_worker_index, lock):
if max_threads_per_process is None:
worker_dict = init_func(*init_args)
else:
with threadpool_limits(limits=max_threads_per_process):
worker_dict = init_func(*init_args)
thread_data._func_wrapper = WorkerFuncWrapper(func, worker_dict, max_threads_per_process)
# print("ici", thread_data._func_wrapper)

def thread_function_wrapper(args):
thread_data = args[0]
args = args[1:]
# thread_data = threading.local()
# print("la", thread_data._func_wrapper)
return thread_data._func_wrapper(args)
if need_worker_index:
lock.acquire()
global _thread_started
worker_index = _thread_started
_thread_started += 1
worker_dict["worker_index"] = worker_index
lock.release()

thread_local_data.func_wrapper = WorkerFuncWrapper(func, worker_dict, max_threads_per_process)

def thread_function_wrapper(args):
thread_local_data = args[0]
args = args[1:]
return thread_local_data.func_wrapper(args)



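For orientation, here is a minimal, self-contained sketch of the slot-claiming pattern that process_worker_initializer relies on above. It is illustrative only, not part of the commit, and the names (_initializer, _task) are made up: a shared multiprocessing.Array of PIDs plus a lock lets each child process claim the first free slot and use that slot's position as its worker_index.

import multiprocessing
import time
from concurrent.futures import ProcessPoolExecutor

_worker_index = None  # per-process global, set once by the pool initializer


def _initializer(lock, array_pid):
    global _worker_index
    pid = multiprocessing.current_process().ident
    with lock:
        # claim the first free slot in the shared PID array
        for i in range(len(array_pid)):
            if array_pid[i] == -1:
                array_pid[i] = pid
                _worker_index = i
                break


def _task(_):
    time.sleep(0.01)  # give every worker a chance to receive some work
    return _worker_index


if __name__ == "__main__":
    n_jobs = 4
    lock = multiprocessing.Lock()
    array_pid = multiprocessing.Array("i", n_jobs)
    for i in range(n_jobs):
        array_pid[i] = -1  # -1 marks an unclaimed slot
    with ProcessPoolExecutor(
        max_workers=n_jobs, initializer=_initializer, initargs=(lock, array_pid)
    ) as executor:
        # typically prints [0, 1, 2, 3]: each worker reports a distinct index
        print(sorted(set(executor.map(_task, range(100)))))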
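The thread engine uses a simpler scheme, sketched below under the same caveat (illustrative names, not the committed code): a module-level counter guarded by a threading.Lock hands out indices inside the pool initializer, and each thread keeps its own state on a threading.local() instance.

import threading
import time
from concurrent.futures import ThreadPoolExecutor

thread_local_data = threading.local()
_thread_started = 0  # global counter, reset before each pool is created


def _initializer(lock):
    global _thread_started
    with lock:
        # hand out indices 0, 1, ... in thread start order
        thread_local_data.worker_index = _thread_started
        _thread_started += 1


def _task(_):
    time.sleep(0.01)  # slow tasks force the pool to actually spawn both threads
    return thread_local_data.worker_index


lock = threading.Lock()
with ThreadPoolExecutor(max_workers=2, initializer=_initializer, initargs=(lock,)) as executor:
    # typically prints [0, 1]
    print(sorted(set(executor.map(_task, range(50)))))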
59 changes: 51 additions & 8 deletions src/spikeinterface/core/tests/test_job_tools.py
@@ -1,6 +1,8 @@
import pytest
import os

import time

from spikeinterface.core import generate_recording, set_global_job_kwargs, get_global_job_kwargs

from spikeinterface.core.job_tools import (
@@ -77,22 +79,22 @@ def test_ensure_chunk_size():
assert end_frame == recording.get_num_frames(segment_index=segment_index)


def func(segment_index, start_frame, end_frame, worker_ctx):
def func(segment_index, start_frame, end_frame, worker_dict):
import os
import time

#  print('func', segment_index, start_frame, end_frame, worker_ctx, os.getpid())
#  print('func', segment_index, start_frame, end_frame, worker_dict, os.getpid())
time.sleep(0.010)
# time.sleep(1.0)
return os.getpid()


def init_func(arg1, arg2, arg3):
worker_ctx = {}
worker_ctx["arg1"] = arg1
worker_ctx["arg2"] = arg2
worker_ctx["arg3"] = arg3
return worker_ctx
worker_dict = {}
worker_dict["arg1"] = arg1
worker_dict["arg2"] = arg2
worker_dict["arg3"] = arg3
return worker_dict


def test_ChunkRecordingExecutor():
@@ -235,10 +237,51 @@ def test_split_job_kwargs():
assert "other_param" not in job_kwargs and "n_jobs" in job_kwargs and "progress_bar" in job_kwargs




def func2(segment_index, start_frame, end_frame, worker_dict):
time.sleep(0.010)
# print(os.getpid(), worker_dict["worker_index"])
return worker_dict["worker_index"]


def init_func2():
# this leaves time for other threads/processes to start
time.sleep(0.010)
worker_dict = {}
return worker_dict


def test_worker_index():
recording = generate_recording(num_channels=2)
init_args = tuple()

for i in range(2):
# running this twice ensures that global variables are correctly reset
for pool_engine in ("process", "thread"):
processor = ChunkRecordingExecutor(
recording,
func2,
init_func2,
init_args,
progress_bar=False,
gather_func=None,
pool_engine=pool_engine,
n_jobs=2,
handle_returns=True,
chunk_duration="200ms",
need_worker_index=True
)
res = processor.run()
# we should have a mix of 0 and 1
assert 0 in res
assert 1 in res

if __name__ == "__main__":
# test_divide_segment_into_chunks()
# test_ensure_n_jobs()
# test_ensure_chunk_size()
test_ChunkRecordingExecutor()
# test_ChunkRecordingExecutor()
# test_fix_job_kwargs()
# test_split_job_kwargs()
test_worker_index()
16 changes: 12 additions & 4 deletions src/spikeinterface/core/tests/test_waveform_tools.py
@@ -176,17 +176,25 @@ def test_estimate_templates_with_accumulator():
templates = estimate_templates_with_accumulator(
recording, spikes, sorting.unit_ids, nbefore, nafter, return_scaled=True, **job_kwargs
)
print(templates.shape)
# print(templates.shape)
assert templates.shape[0] == sorting.unit_ids.size
assert templates.shape[1] == nbefore + nafter
assert templates.shape[2] == recording.get_num_channels()

assert np.any(templates != 0)

job_kwargs = dict(n_jobs=1, progress_bar=True, chunk_duration="1s")
templates_loop = estimate_templates_with_accumulator(
recording, spikes, sorting.unit_ids, nbefore, nafter, return_scaled=True, **job_kwargs
)
np.testing.assert_almost_equal(templates, templates_loop, decimal=4)

# import matplotlib.pyplot as plt
# fig, ax = plt.subplots()
# for unit_index, unit_id in enumerate(sorting.unit_ids):
# ax.plot(templates[unit_index, :, :].T.flatten())
# ax.plot(templates[unit_index, :, :].T.flatten())
# ax.plot(templates_loop[unit_index, :, :].T.flatten(), color="k", ls="--")
# ax.plot((templates - templates_loop)[unit_index, :, :].T.flatten(), color="k", ls="--")
# plt.show()


@@ -225,6 +233,6 @@ def test_estimate_templates():


if __name__ == "__main__":
test_waveform_tools()
# test_waveform_tools()
test_estimate_templates_with_accumulator()
test_estimate_templates()
# test_estimate_templates()
