Commit d9cb43d

improve process pool
Arjun Barrett committed Jun 11, 2024
1 parent b24c3d6 commit d9cb43d
Showing 2 changed files with 29 additions and 53 deletions.
2 changes: 1 addition & 1 deletion python/vmaf/core/executor.py
@@ -182,7 +182,7 @@ def _run(asset_lock):
                 lock.release()
                 return result
 
-            self.results = parallel_map(_run, list_args, processes=processes, sleep_sec=0.1)
+            self.results = parallel_map(_run, list_args, processes=processes)
         else:
             self.results = list(map(self._run_on_asset, self.assets))

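The sleep_sec argument disappears from the call site because the caller no longer drives a polling loop: multiprocessing's Pool.map blocks until every task has finished and never runs more than the requested number of workers at once. A minimal standalone sketch of those two guarantees (the square helper is hypothetical, not part of the repository):

import multiprocessing

def square(x):
    # Trivial stand-in for the real per-asset work.
    return x * x

if __name__ == '__main__':
    # Blocks until all ten tasks complete; at most 4 worker processes run at a time.
    with multiprocessing.Pool(processes=4) as pool:
        results = pool.map(square, range(10))
    print(results)  # [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]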
80 changes: 28 additions & 52 deletions python/vmaf/tools/misc.py
@@ -306,64 +306,40 @@ def index_and_value_of_min(l):
     """
     return min(enumerate(l), key=lambda x: x[1])
 
+_pm_return_dict = None
+_pm_func = None
+_pm_list_args = None
+
+
+def _parallel_map_rt(idx):
+    _pm_return_dict[idx] = _pm_func(_pm_list_args[idx])
 
-def parallel_map(func, list_args, processes=None, sleep_sec=0.01):
+
+def parallel_map(func, list_args, processes=None):
     """
-    Build my own parallelized map function since multiprocessing's Process(),
-    or Pool.map() cannot meet my both needs:
-    1) be able to control the maximum number of processes in parallel
-    2) be able to take in non-picklable objects as arguments
+    Use multiprocessing.Pool to create a fast parallel map that doesn't pickle arguments
+    Note: only works on Unix! fork() used to propagate unpickleable data
     """
+    context = multiprocessing
 
-    # get maximum number of active processes that can be used
-    max_active_procs = processes if processes is not None else multiprocessing.cpu_count()
+    if getattr(multiprocessing, 'get_context', None) is None:
+        assert os.name == 'posix', "parallel_map() requires fork() support, but not running on Unix"
+    else:
+        assert 'fork' in multiprocessing.get_all_start_methods(), "parallel_map() requires fork() support"
+        context = multiprocessing.get_context('fork')
 
     # create shared dictionary
-    return_dict = multiprocessing.Manager().dict()
-
-    # define runner function
-    def func_wrapper(idx_args):
-        idx, args = idx_args
-        executor = func(args)
-        return_dict[idx] = executor
+    return_dict = context.Manager().dict()
 
-    # add idx to args
-    list_idx_args = []
-    for idx, args in enumerate(list_args):
-        list_idx_args.append((idx, args))
-
-    procs = []
-    for idx_args in list_idx_args:
-        proc = multiprocessing.Process(target=func_wrapper, args=(idx_args,))
-        procs.append(proc)
-
-    waiting_procs = set(procs)
-    active_procs = set([])
-
-    # processing
-    while True:
-
-        # check if any procs in active_procs is done; if yes, remove them
-        for p in active_procs.copy():
-            if not p.is_alive():
-                active_procs.remove(p)
-
-        # check if we can add a proc to active_procs (add gradually one per loop)
-        if len(active_procs) < max_active_procs and len(waiting_procs) > 0:
-            # move one proc from waiting_procs to active_procs
-            p = waiting_procs.pop()
-            active_procs.add(p)
-            p.start()
-
-        # if both waiting_procs and active_procs are empty, can terminate
-        if len(waiting_procs) == 0 and len(active_procs) == 0:
-            break
-
-        sleep(sleep_sec)  # check every x sec
-
-    # finally, collect results
-    rets = list(map(lambda idx: return_dict[idx], range(len(list_args))))
-    return rets
+    def pool_init():
+        global _pm_func, _pm_list_args, _pm_return_dict
+        _pm_func = func
+        _pm_list_args = list_args
+        _pm_return_dict = return_dict
+
+    with context.Pool(processes, initializer=pool_init) as pool:
+        pool.map(_parallel_map_rt, range(len(list_args)))
+
+    return [return_dict[i] for i in range(len(list_args))]
 
 
 def check_program_exist(program):
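With this structure, the worker processes are fork()ed and receive func, list_args, and the shared dict through the pool_init initializer, so only integer indices are ever pickled and sent to the workers; results travel back through the Manager dict and are reassembled in input order. A hypothetical usage sketch (not from the repository) that would break an ordinary spawn-based Pool.map, since lambdas cannot be pickled:

from vmaf.tools.misc import parallel_map

if __name__ == '__main__':
    # Non-picklable arguments: plain lambdas.
    tasks = [lambda: 1, lambda: 2, lambda: 3]
    # func itself is also a non-picklable closure; both are inherited via fork().
    results = parallel_map(lambda f: f(), tasks, processes=2)
    print(results)  # [1, 2, 3] -- returned in input order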
