From 5f99f86a9367d24546a6f205fea74b9ead132c9e Mon Sep 17 00:00:00 2001
From: RichieHakim
Date: Thu, 1 Feb 2024 16:11:44 -0500
Subject: [PATCH] Bug fix for map_parallel and refactor other stuff to change
 arg to n_workers

---
 bnpm/image_processing.py |  2 +-
 bnpm/parallel_helpers.py | 62 ++++++++++++++++++++++------------------
 bnpm/timeSeries.py       | 27 +++++++++++------
 3 files changed, 53 insertions(+), 38 deletions(-)

diff --git a/bnpm/image_processing.py b/bnpm/image_processing.py
index d3445b3..be2175c 100644
--- a/bnpm/image_processing.py
+++ b/bnpm/image_processing.py
@@ -476,7 +476,7 @@ def warp_sparse_image(
         return warped_sparse_image
 
     wsi_partial = partial(warp_sparse_image, remappingIdx=remappingIdx)
-    ims_sparse_out = parallel_helpers.map_parallel(func=wsi_partial, args=[ims_sparse,], method='multithreading', workers=n_workers, prog_bar=verbose)
+    ims_sparse_out = parallel_helpers.map_parallel(func=wsi_partial, args=[ims_sparse,], method='multithreading', n_workers=n_workers, prog_bar=verbose)
 
     return ims_sparse_out
 
diff --git a/bnpm/parallel_helpers.py b/bnpm/parallel_helpers.py
index d8127a4..002a9b1 100644
--- a/bnpm/parallel_helpers.py
+++ b/bnpm/parallel_helpers.py
@@ -28,7 +28,7 @@ def map_parallel(
     func: Callable,
     args: List[Any],
     method: str = 'multithreading',
-    workers: int = -1,
+    n_workers: int = -1,
     prog_bar: bool = True
 ) -> List[Any]:
     """
@@ -63,40 +63,28 @@ def map_parallel(
 
         .. highlight::python
         .. code-block::python
 
-            result = map_parallel(max, [[1,2,3,4],[5,6,7,8]], method='multiprocessing', workers=3)
+            result = map_parallel(max, [[1,2,3,4],[5,6,7,8]], method='multiprocessing', n_workers=3)
     """
-    if workers == -1:
-        workers = mp.cpu_count()
-
-    ## Get number of arguments. If args is a generator, make None.
-    n_args = len(args[0]) if hasattr(args, '__len__') else None
+    if n_workers == -1:
+        n_workers = mp.cpu_count()
 
     ## Assert that args is a list
     assert isinstance(args, list), "args must be a list"
+    ## Assert that each element of args is an iterable
+    assert all([hasattr(arg, '__iter__') for arg in args]), "All elements of args must be iterable"
+    ## Assert that each element has a length
+    assert all([hasattr(arg, '__len__') for arg in args]), "All elements of args must have a length"
+    ## Get the number of jobs; the asserts above guarantee args is a list of sized iterables
+    n_args = len(args[0])
     ## Assert that all args are the same length
     assert all([len(arg) == n_args for arg in args]), "All args must be the same length"
 
     ## Make indices
     indices = np.arange(n_args)
 
-    def wrapper(*args_index):
-        """
-        Wrapper function to catch exceptions.
-
-        Args:
-            *args_index (tuple):
-                Tuple of arguments to be passed to the function.
-                Should take the form of (arg1, arg2, ..., argN, index)
-                The last element is the index of the job.
-        """
-        index = args_index[-1]
-        args = args_index[:-1]
-
-        try:
-            return func(*args)
-        except Exception as e:
-            raise ParallelExecutionError(index, e)
+    ## Prepare args_map (input to map function)
+    args_map = [[func] * n_args, *args, indices]
 
     if method == 'multithreading':
         executor = ThreadPoolExecutor
@@ -109,13 +97,31 @@ def wrapper(*args_index):
         # import joblib
         # return joblib.Parallel(n_jobs=workers)(joblib.delayed(func)(arg) for arg in tqdm(args, total=n_args, disable=prog_bar!=True))
     elif method == 'serial':
-        # return [func(*arg) for arg in tqdm(args, disable=prog_bar!=True)]
-        return list(tqdm(map(wrapper, *(args + [indices])), total=n_args, disable=prog_bar!=True))
+        return list(tqdm(map(_func_wrapper_helper, *args_map), total=n_args, disable=prog_bar!=True))
     else:
         raise ValueError(f"method {method} not recognized")
 
-    with executor(workers) as ex:
-        return list(tqdm(ex.map(wrapper, *(args + [indices])), total=n_args, disable=prog_bar!=True))
+    with executor(n_workers) as ex:
+        return list(tqdm(ex.map(_func_wrapper_helper, *args_map), total=n_args, disable=prog_bar!=True))
 
+def _func_wrapper_helper(*func_args_index):
+    """
+    Wrapper function to catch exceptions.
+
+    Args:
+        *func_args_index (tuple):
+            Tuple of arguments to be passed to the function.
+            Should take the form of (func, arg1, arg2, ..., argN, index)
+            The last element is the index of the job.
+    """
+    func = func_args_index[0]
+    args = func_args_index[1:-1]
+    index = func_args_index[-1]
+
+    try:
+        return func(*args)
+    except Exception as e:
+        raise ParallelExecutionError(index, e)
+
 
 def multiprocessing_pool_along_axis(x_in, function, n_workers=None, axis=0, **kwargs):
diff --git a/bnpm/timeSeries.py b/bnpm/timeSeries.py
index f753cb5..6b08aa2 100644
--- a/bnpm/timeSeries.py
+++ b/bnpm/timeSeries.py
@@ -70,9 +70,9 @@ def convFun_axis1(iter):
     kernel = np.ascontiguousarray(kernel)
 
     if axis==0:
-        output_list = parallel_helpers.map_parallel(convFun_axis0, [range(array.shape[1])], method='multithreading', workers=-1, prog_bar=verbose)
+        output_list = parallel_helpers.map_parallel(convFun_axis0, [range(array.shape[1])], method='multithreading', n_workers=-1, prog_bar=verbose)
     if axis==1:
-        output_list = parallel_helpers.map_parallel(convFun_axis1, [range(array.shape[0])], method='multithreading', workers=-1, prog_bar=verbose)
+        output_list = parallel_helpers.map_parallel(convFun_axis1, [range(array.shape[0])], method='multithreading', n_workers=-1, prog_bar=verbose)
 
     if verbose:
         print(f'ThreadPool elapsed time : {round(time.time() - tic , 2)} s. Now unpacking list into array.')
@@ -280,23 +280,32 @@ def rolling_percentile_pd(
         kwargs_rolling['closed'] = None
 
     from functools import partial
-    _rolling_ptile_pd_helper_partial = partial(_rolling_ptile_pd_helper, win=int(window), ptile=ptile, kwargs_rolling=kwargs_rolling, interpolation=interpolation)
+    # _rolling_ptile_pd_helper_partial = partial(_rolling_ptile_pd_helper, win=int(window), ptile=ptile, kwargs_rolling=kwargs_rolling, interpolation=interpolation)
+    ## Avoid using partial because it doesn't work with multiprocessing; pass the constant parameters as broadcast lists instead
 
     if multiprocessing_pref:
         from .parallel_helpers import map_parallel
-        from .indexing import make_batches
         import multiprocessing as mp
-        n_batches = mp.cpu_count()
-        batches = make_batches(X, num_batches=n_batches)
+        batches = list(indexing.make_batches(X, num_batches=mp.cpu_count()))
+        n_batches = len(batches)
+        ## Make args a list of iterables, each of length n_batches; this is the input format expected by map and map_parallel
+        args = [
+            batches,
+            [int(window)] * n_batches,
+            [ptile] * n_batches,
+            [kwargs_rolling] * n_batches,
+            [interpolation] * n_batches
+        ]
         output = map_parallel(
-            _rolling_ptile_pd_helper_partial,
-            [list(batches)],
+            _rolling_ptile_pd_helper,
+            args,
             method='multiprocessing',
+            n_workers=-1,
             prog_bar=prog_bar,
         )
         output = np.concatenate(output, axis=0)
     else:
-        output = _rolling_ptile_pd_helper_partial(X)
+        output = _rolling_ptile_pd_helper(X, win=int(window), ptile=ptile, kwargs_rolling=kwargs_rolling, interpolation=interpolation)
     return output
 
 def _rolling_ptile_pd_helper(X, win, ptile, kwargs_rolling, interpolation='linear'):
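For reviewers, a minimal usage sketch of the calling convention after this patch. It is not part of the commit; `scale` is a hypothetical function, and the import path assumes the bnpm package layout shown in the diff. Each element of `args` is one iterable of per-job values, so constant parameters are broadcast as repeated lists, the same pattern rolling_percentile_pd now uses:

    from bnpm.parallel_helpers import map_parallel

    def scale(x, factor):  ## module-level, so it also pickles for method='multiprocessing'
        return x * factor

    xs = [1, 2, 3, 4]
    factors = [10] * len(xs)  ## broadcast a constant: one copy per job

    ## args is a list of equal-length iterables; job i receives
    ## (xs[i], factors[i]) as positional arguments.
    results = map_parallel(
        func=scale,
        args=[xs, factors],
        method='multithreading',
        n_workers=-1,  ## -1 resolves to mp.cpu_count()
        prog_bar=False,
    )
    assert results == [10, 20, 30, 40]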
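A note on why the exception wrapper moved from a closure inside map_parallel to the module-level _func_wrapper_helper: ProcessPoolExecutor must pickle the mapped callable, and locally defined functions cannot be pickled. A standard-library-only sketch of that constraint (names are illustrative):

    import pickle

    def top_level(x):
        return x + 1  ## module-scope functions pickle by reference

    def make_nested():
        def nested(x):  ## local functions have no importable qualified name
            return x + 1
        return nested

    pickle.dumps(top_level)  ## fine
    try:
        pickle.dumps(make_nested())
    except (pickle.PicklingError, AttributeError) as e:
        print(e)  ## "Can't pickle local object 'make_nested.<locals>.nested'"

This is presumably the bug the subject line refers to: with method='multiprocessing', the old nested wrapper could not be sent to worker processes.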
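The wrapper tags failures with the job index via ParallelExecutionError. A sketch of catching it, assuming ParallelExecutionError is defined and importable from bnpm.parallel_helpers as the diff implies (its exact message format is outside this diff):

    from bnpm.parallel_helpers import map_parallel, ParallelExecutionError

    def reciprocal(x):
        return 1 / x

    try:
        map_parallel(reciprocal, [[2, 1, 0, 4]], method='serial', prog_bar=False)
    except ParallelExecutionError as err:
        ## err was constructed as ParallelExecutionError(index, e), so it
        ## identifies job 2 and wraps the underlying ZeroDivisionError
        print(err)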