diff --git a/python/vmaf/core/executor.py b/python/vmaf/core/executor.py
index 52dde844e..d2da1375d 100644
--- a/python/vmaf/core/executor.py
+++ b/python/vmaf/core/executor.py
@@ -182,7 +182,7 @@ def _run(asset_lock):
                 lock.release()
                 return result
 
-            self.results = parallel_map(_run, list_args, processes=processes, sleep_sec=0.1)
+            self.results = parallel_map(_run, list_args, processes=processes)
         else:
             self.results = list(map(self._run_on_asset, self.assets))
 
diff --git a/python/vmaf/tools/misc.py b/python/vmaf/tools/misc.py
index dc82f07d5..e469045eb 100644
--- a/python/vmaf/tools/misc.py
+++ b/python/vmaf/tools/misc.py
@@ -306,64 +306,40 @@ def index_and_value_of_min(l):
     """
     return min(enumerate(l), key=lambda x: x[1])
 
+_pm_return_dict = None
+_pm_func = None
+_pm_list_args = None
 
-def parallel_map(func, list_args, processes=None, sleep_sec=0.01):
+def _parallel_map_rt(idx):
+    _pm_return_dict[idx] = _pm_func(_pm_list_args[idx])
+
+def parallel_map(func, list_args, processes=None):
     """
-    Build my own parallelized map function since multiprocessing's Process(),
-    or Pool.map() cannot meet my both needs:
-    1) be able to control the maximum number of processes in parallel
-    2) be able to take in non-picklable objects as arguments
+    Use multiprocessing.Pool to create a fast parallel map that doesn't pickle arguments
+    Note: only works on Unix! fork() used to propagate unpickleable data
     """
-    # get maximum number of active processes that can be used
-    max_active_procs = processes if processes is not None else multiprocessing.cpu_count()
-
-    # create shared dictionary
-    return_dict = multiprocessing.Manager().dict()
-
-    # define runner function
-    def func_wrapper(idx_args):
-        idx, args = idx_args
-        executor = func(args)
-        return_dict[idx] = executor
-
-    # add idx to args
-    list_idx_args = []
-    for idx, args in enumerate(list_args):
-        list_idx_args.append((idx, args))
-
-    procs = []
-    for idx_args in list_idx_args:
-        proc = multiprocessing.Process(target=func_wrapper, args=(idx_args,))
-        procs.append(proc)
-
-    waiting_procs = set(procs)
-    active_procs = set([])
-
-    # processing
-    while True:
-
-        # check if any procs in active_procs is done; if yes, remove them
-        for p in active_procs.copy():
-            if not p.is_alive():
-                active_procs.remove(p)
-
-        # check if we can add a proc to active_procs (add gradually one per loop)
-        if len(active_procs) < max_active_procs and len(waiting_procs) > 0:
-            # move one proc from waiting_procs to active_procs
-            p = waiting_procs.pop()
-            active_procs.add(p)
-            p.start()
-
-        # if both waiting_procs and active_procs are empty, can terminate
-        if len(waiting_procs) == 0 and len(active_procs) == 0:
-            break
+    context = multiprocessing
 
-        sleep(sleep_sec)  # check every x sec
+    if getattr(multiprocessing, 'get_context', None) is None:
+        assert os.name == 'posix', "parallel_map() requires fork() support, but not running on Unix"
+    else:
+        assert 'fork' in multiprocessing.get_all_start_methods(), "parallel_map() requires fork() support"
+        context = multiprocessing.get_context('fork')
 
-    # finally, collect results
-    rets = list(map(lambda idx: return_dict[idx], range(len(list_args))))
-    return rets
+    # create shared dictionary
+    return_dict = context.Manager().dict()
+
+    def pool_init():
+        global _pm_func, _pm_list_args, _pm_return_dict
+        _pm_func = func
+        _pm_list_args = list_args
+        _pm_return_dict = return_dict
+
+    with context.Pool(processes, initializer=pool_init) as pool:
+        pool.map(_parallel_map_rt, range(len(list_args)))
+
+    return [return_dict[i] for i in range(len(list_args))]
 
 
 def check_program_exist(program):
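
For reference, here is a minimal standalone sketch of the fork-plus-initializer pattern the new parallel_map relies on. It is an illustration under stated assumptions, not code from this change: fork_map, _worker, and the _g_* globals are hypothetical names. The idea is that with the 'fork' start method the Pool initializer (a closure capturing the callable and its arguments) is inherited by the workers directly, so nothing unpicklable ever has to be pickled.

```python
# Sketch only (not part of this diff): fork-based Pool + initializer so that an
# unpicklable callable and its arguments reach the workers without pickling.
import multiprocessing

# Illustration-only globals; each forked worker fills them in via the initializer.
_g_func = None
_g_args = None
_g_results = None


def _worker(idx):
    # Runs in a forked child; the _g_* globals were set by _init() in this child.
    _g_results[idx] = _g_func(_g_args[idx])


def fork_map(func, list_args, processes=None):
    ctx = multiprocessing.get_context('fork')  # fork() is POSIX-only
    results = ctx.Manager().dict()             # proxy dict shared with the workers

    def _init():
        # Publish the (possibly unpicklable) func/args to this worker's globals.
        global _g_func, _g_args, _g_results
        _g_func, _g_args, _g_results = func, list_args, results

    with ctx.Pool(processes, initializer=_init) as pool:
        pool.map(_worker, range(len(list_args)))

    # Reassemble results in input order.
    return [results[i] for i in range(len(list_args))]


if __name__ == '__main__':
    # A lambda is not picklable, yet it can be mapped because it travels via fork().
    print(fork_map(lambda x: x * x, [1, 2, 3]))  # -> [1, 4, 9]
```

Because the workers are forked rather than spawned, the closure passed as initializer is never serialized; this is the same mechanism the diff uses, and the reason the docstring restricts parallel_map() to platforms with fork() support.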