Bug fix for map_parallel and refactor other stuff to change arg to n_workers
RichieHakim committed Feb 1, 2024
1 parent 3192140 commit 5f99f86
Showing 3 changed files with 53 additions and 38 deletions.
bnpm/image_processing.py (2 changes: 1 addition & 1 deletion)
@@ -476,7 +476,7 @@ def warp_sparse_image(
return warped_sparse_image

wsi_partial = partial(warp_sparse_image, remappingIdx=remappingIdx)
-ims_sparse_out = parallel_helpers.map_parallel(func=wsi_partial, args=[ims_sparse,], method='multithreading', workers=n_workers, prog_bar=verbose)
+ims_sparse_out = parallel_helpers.map_parallel(func=wsi_partial, args=[ims_sparse,], method='multithreading', n_workers=n_workers, prog_bar=verbose)
return ims_sparse_out


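The only change in this file is the renamed keyword (workers → n_workers) at the call site, but the call illustrates the pattern used throughout the commit: bind the shared argument with functools.partial, then fan the single remaining argument out over a thread pool. Below is a minimal sketch of that pattern; the `warp` function and the toy inputs are hypothetical stand-ins, not the library's actual implementation.

```python
from functools import partial
from concurrent.futures import ThreadPoolExecutor

def warp(im, remappingIdx):
    # Hypothetical stand-in for warp_sparse_image: permute one image by an index map.
    return [im[i] for i in remappingIdx]

ims = [[10, 20, 30], [40, 50, 60]]                   # stand-in for ims_sparse
wsi_partial = partial(warp, remappingIdx=[2, 0, 1])  # bind the shared argument

# Rough serial-free equivalent of map_parallel(..., method='multithreading', n_workers=2):
with ThreadPoolExecutor(2) as ex:
    ims_out = list(ex.map(wsi_partial, ims))

print(ims_out)  # [[30, 10, 20], [60, 40, 50]]
```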
bnpm/parallel_helpers.py (62 changes: 34 additions & 28 deletions)
@@ -28,7 +28,7 @@ def map_parallel(
func: Callable,
args: List[Any],
method: str = 'multithreading',
-workers: int = -1,
+n_workers: int = -1,
prog_bar: bool = True
) -> List[Any]:
"""
@@ -63,40 +63,28 @@
.. highlight::python
.. code-block::python
-result = map_parallel(max, [[1,2,3,4],[5,6,7,8]], method='multiprocessing', workers=3)
+result = map_parallel(max, [[1,2,3,4],[5,6,7,8]], method='multiprocessing', n_workers=3)
"""
-if workers == -1:
-workers = mp.cpu_count()

-## Get number of arguments. If args is a generator, make None.
-n_args = len(args[0]) if hasattr(args, '__len__') else None
+if n_workers == -1:
+n_workers = mp.cpu_count()

## Assert that args is a list
assert isinstance(args, list), "args must be a list"
## Assert that each element of args is an iterable
assert all([hasattr(arg, '__iter__') for arg in args]), "All elements of args must be iterable"

## Assert that each element has a length
assert all([hasattr(arg, '__len__') for arg in args]), "All elements of args must have a length"
+## Get number of arguments. If args is a generator, make None.
+n_args = len(args[0]) if hasattr(args, '__len__') else None
## Assert that all args are the same length
assert all([len(arg) == n_args for arg in args]), "All args must be the same length"

## Make indices
indices = np.arange(n_args)

-def wrapper(*args_index):
-"""
-Wrapper function to catch exceptions.
-Args:
-*args_index (tuple):
-Tuple of arguments to be passed to the function.
-Should take the form of (arg1, arg2, ..., argN, index)
-The last element is the index of the job.
-"""
-index = args_index[-1]
-args = args_index[:-1]
-
-try:
-return func(*args)
-except Exception as e:
-raise ParallelExecutionError(index, e)
+## Prepare args_map (input to map function)
+args_map = [[func] * n_args, *args, indices]

if method == 'multithreading':
executor = ThreadPoolExecutor
@@ -109,13 +97,31 @@ def wrapper(*args_index):
# import joblib
# return joblib.Parallel(n_jobs=workers)(joblib.delayed(func)(arg) for arg in tqdm(args, total=n_args, disable=prog_bar!=True))
elif method == 'serial':
-# return [func(*arg) for arg in tqdm(args, disable=prog_bar!=True)]
-return list(tqdm(map(wrapper, *(args + [indices])), total=n_args, disable=prog_bar!=True))
+return list(tqdm(map(_func_wrapper_helper, *args_map), total=n_args, disable=prog_bar!=True))
else:
raise ValueError(f"method {method} not recognized")

-with executor(workers) as ex:
-return list(tqdm(ex.map(wrapper, *(args + [indices])), total=n_args, disable=prog_bar!=True))
+with executor(n_workers) as ex:
+return list(tqdm(ex.map(_func_wrapper_helper, *args_map), total=n_args, disable=prog_bar!=True))
+def _func_wrapper_helper(*func_args_index):
+"""
+Wrapper function to catch exceptions.
+Args:
+*func_args_index (tuple):
+Tuple of arguments to be passed to the function.
+Should take the form of (func, arg1, arg2, ..., argN, index)
+The last element is the index of the job.
+"""
+func = func_args_index[0]
+args = func_args_index[1:-1]
+index = func_args_index[-1]
+
+try:
+return func(*args)
+except Exception as e:
+raise ParallelExecutionError(index, e)



def multiprocessing_pool_along_axis(x_in, function, n_workers=None, axis=0, **kwargs):
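Two things changed in map_parallel beyond the keyword rename. First, the exception-catching wrapper moved from a closure inside map_parallel to the module-level _func_wrapper_helper, presumably because a locally defined closure cannot be pickled for dispatch by the multiprocessing executor. Second, func is no longer captured from the enclosing scope: args_map = [[func] * n_args, *args, indices] supplies one column of functions, N columns of arguments, and one column of job indices, and map (or Executor.map) zips them row-wise into (func, arg1, ..., argN, index) calls. A minimal sketch of that dispatch using the built-in map, with a stand-in error class (the real ParallelExecutionError is defined elsewhere in the module; only its (index, exception) constructor is visible in this diff):

```python
class ParallelExecutionError(Exception):
    # Stand-in matching the (index, exception) constructor seen in the diff.
    def __init__(self, index, original_exception):
        super().__init__(f"job {index} failed: {original_exception!r}")

def _func_wrapper_helper(*func_args_index):
    # Unpack (func, arg1, ..., argN, index); run func and tag failures with the job index.
    func = func_args_index[0]
    args = func_args_index[1:-1]
    index = func_args_index[-1]
    try:
        return func(*args)
    except Exception as e:
        raise ParallelExecutionError(index, e)

func = max
args = [[[1, 2, 3, 4], [5, 6, 7, 8]]]  # one iterable per positional argument of func
n_args = len(args[0])
indices = range(n_args)

# One funcs column, N argument columns, one index column; map zips them row-wise.
args_map = [[func] * n_args, *args, indices]
print(list(map(_func_wrapper_helper, *args_map)))  # [4, 8]
```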
bnpm/timeSeries.py (27 changes: 18 additions & 9 deletions)
@@ -70,9 +70,9 @@ def convFun_axis1(iter):

kernel = np.ascontiguousarray(kernel)
if axis==0:
-output_list = parallel_helpers.map_parallel(convFun_axis0, [range(array.shape[1])], method='multithreading', workers=-1, prog_bar=verbose)
+output_list = parallel_helpers.map_parallel(convFun_axis0, [range(array.shape[1])], method='multithreading', n_workers=-1, prog_bar=verbose)
if axis==1:
-output_list = parallel_helpers.map_parallel(convFun_axis1, [range(array.shape[0])], method='multithreading', workers=-1, prog_bar=verbose)
+output_list = parallel_helpers.map_parallel(convFun_axis1, [range(array.shape[0])], method='multithreading', n_workers=-1, prog_bar=verbose)

if verbose:
print(f'ThreadPool elapsed time : {round(time.time() - tic , 2)} s. Now unpacking list into array.')
@@ -280,23 +280,32 @@ def rolling_percentile_pd(
kwargs_rolling['closed'] = None

from functools import partial
-_rolling_ptile_pd_helper_partial = partial(_rolling_ptile_pd_helper, win=int(window), ptile=ptile, kwargs_rolling=kwargs_rolling, interpolation=interpolation)
+# _rolling_ptile_pd_helper_partial = partial(_rolling_ptile_pd_helper, win=int(window), ptile=ptile, kwargs_rolling=kwargs_rolling, interpolation=interpolation)
+## Avoid using partial because it doesn't work with multiprocessing

if multiprocessing_pref:
from .parallel_helpers import map_parallel
from .indexing import make_batches
import multiprocessing as mp
-n_batches = mp.cpu_count()
-batches = make_batches(X, num_batches=n_batches)
+batches = list(indexing.make_batches(X, num_batches=mp.cpu_count()))
+n_batches = len(batches)
+## Make args as a list of iterables, each with length n_batches. This is the format for map and map_parallel
+args = [
+batches,
+[int(window)] * n_batches,
+[ptile] * n_batches,
+[kwargs_rolling] * n_batches,
+[interpolation] * n_batches
+]
output = map_parallel(
-_rolling_ptile_pd_helper_partial,
-[list(batches)],
+_rolling_ptile_pd_helper,
+args,
method='multiprocessing',
-workers=-1,
+n_workers=-1,
prog_bar=prog_bar,
)
output = np.concatenate(output, axis=0)
else:
-output = _rolling_ptile_pd_helper_partial(X)
+output = _rolling_ptile_pd_helper(X, win=int(window), ptile=ptile, kwargs_rolling=kwargs_rolling, interpolation=interpolation)

return output
def _rolling_ptile_pd_helper(X, win, ptile, kwargs_rolling, interpolation='linear'):
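In rolling_percentile_pd the same concern drives the second refactor: instead of binding win, ptile, and the rolling options into a partial (which the new comment says does not work with multiprocessing here), the scalars are broadcast into per-batch lists so that args matches map_parallel's expected format of one equal-length iterable per positional parameter. A minimal sketch of the broadcast pattern using the built-in map in place of map_parallel, with a hypothetical roll_ptile stand-in for _rolling_ptile_pd_helper and np.array_split standing in for indexing.make_batches:

```python
import numpy as np
import pandas as pd

def roll_ptile(X, win, ptile):
    # Hypothetical stand-in for _rolling_ptile_pd_helper: rolling percentile of one batch.
    return pd.DataFrame(X).rolling(win, min_periods=1).quantile(ptile / 100).to_numpy()

X = np.arange(12, dtype=float).reshape(4, 3)
batches = np.array_split(X, 2, axis=0)  # stand-in for indexing.make_batches
n_batches = len(batches)

# One iterable per positional argument; scalars broadcast to length n_batches.
args = [
    batches,
    [3] * n_batches,   # win
    [50] * n_batches,  # ptile
]
output = np.concatenate(list(map(roll_ptile, *args)), axis=0)
print(output.shape)  # (4, 3): batch results re-stacked along axis 0
```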
