Skip to content

Commit

Permalink
improved multiprocessing support
Browse files Browse the repository at this point in the history
  • Loading branch information
Johannes Lange committed Apr 18, 2022
1 parent 73778b1 commit 62f2f5b
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 67 deletions.
3 changes: 1 addition & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,7 @@ from tabcorr import TabCorr
rp_bins = np.logspace(-1, 1, 20)
halocat = CachedHaloCatalog(simname='bolplanck')
halotab = TabCorr.tabulate(halocat, wp, rp_bins, pi_max=40,
period=halocat.Lbox)
halotab = TabCorr.tabulate(halocat, wp, rp_bins, pi_max=40)
# We can save the result for later use.
halotab.write('bolplanck.hdf5')
Expand Down
7 changes: 4 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from setuptools import setup

setup(name='tabcorr',
version='0.7.2',
version='0.8.0',
description='Tabulated correlation functions for halotools',
url='https://github.com/johannesulf/TabCorr',
author='Johannes Ulf Lange',
author_email='[email protected]',
author='Johannes U. Lange',
author_email='[email protected]',
packages=['tabcorr'],
install_requires=['numpy', 'scipy', 'astropy', 'h5py', 'tqdm'],
zip_safe=False)
153 changes: 91 additions & 62 deletions tabcorr/tabcorr.py
Original file line number Diff line number Diff line change
@@ -1,87 +1,122 @@
import h5py
import tqdm
import time
import itertools
from multiprocessing import Pool
from functools import partial
import numpy as np
from queue import Empty
from random import shuffle
from multiprocessing import Process, Queue
from scipy.spatial import Delaunay
from astropy.table import Table, vstack
from halotools.sim_manager import sim_defaults
from halotools.empirical_models import HodModelFactory, model_defaults
from halotools.empirical_models import TrivialPhaseSpace, Zheng07Cens
from halotools.empirical_models import NFWPhaseSpace, Zheng07Sats
from halotools.mock_observables import return_xyz_formatted_array
from halotools.sim_manager import sim_defaults
from halotools.utils import crossmatch
from halotools.utils.table_utils import compute_conditional_percentiles


def print_progress(progress):
percent = "{0:.1f}".format(100 * progress)
bar = '=' * int(50 * progress) + '>' + ' ' * int(50 * (1 - progress))
print('\rProgress: |{0}| {1}%'.format(bar, percent), end='\r')
if progress == 1.0:
print()

def compute_tpcf(mode, pos, tpcf, period, tpcf_args, tpcf_kwargs,
input_queue, output_queue):

def compute_tpcf_matrix(mode, pos, tpcf, period, tpcf_args, tpcf_kwargs,
combinations, num_threads=1, verbose=False):
while True:
try:
i = input_queue.get(block=True, timeout=0.1)
if mode == 'auto':
i_1, i_2 = i
xi = tpcf(pos[i_1], *tpcf_args, sample2=pos[i_2] if i_1 != i_2
else None, do_auto=(i_1 == i_2),
do_cross=(i_1 != i_2), period=period, **tpcf_kwargs)
else:
xi = tpcf(pos[i], *tpcf_args, period=period, **tpcf_kwargs)

if num_threads > 1:
output_queue.put((i, xi))

compute_tpcf_matrix_partial = partial(
compute_tpcf_matrix, mode, pos, tpcf, period, tpcf_args,
tpcf_kwargs)
combinations = list(combinations)
combinations = [combinations[i::num_threads] for i in
range(num_threads)]
except Empty:
break

with Pool(num_threads) as pool:
result = pool.map(compute_tpcf_matrix_partial, combinations)
tpcf_matrix = np.sum([r[0] for r in result], axis=0)
tpcf_shape = result[0][1]
return tpcf_matrix, tpcf_shape

n = np.array([len(p) for p in pos])
n_done = 0
def compute_tpcf_matrix(mode, pos, tpcf, period, tpcf_args, tpcf_kwargs,
num_threads=1, verbose=False):

if mode == 'auto':
for i, (s1, s2) in enumerate(combinations):
tpcf_matrix = None

if len(pos[s1]) * len(pos[s2]) == 0:
continue
input_queue = Queue()
output_queue = Queue()

if verbose:
n_done += (n[s1] * n[s2] * (2 if s1 != s2 else 1))
print_progress(n_done / np.sum(n)**2)
if mode == 'auto':
tasks = itertools.combinations_with_replacement(range(len(pos)), 2)
else:
tasks = range(len(pos))

xi = tpcf(pos[s1], *tpcf_args,
sample2=pos[s2] if s1 != s2 else None,
do_auto=(s1 == s2), do_cross=(s1 != s2),
period=period, **tpcf_kwargs)
tasks = list(tasks)
shuffle(tasks)

if i == 0:
tpcf_matrix = np.zeros((len(xi.ravel()), len(pos), len(pos)))
tpcf_shape = xi.shape
n_tot = 0
n = 0

tpcf_matrix[:, s1, s2] += xi.ravel()
tpcf_matrix[:, s2, s1] = tpcf_matrix[:, s1, s2]
for task in tasks:
if mode == 'auto':
i_1, i_2 = task
if len(pos[i_1]) * len(pos[i_2]) > 0:
n_tot += len(pos[i_1]) * len(pos[i_2])
input_queue.put(task)
else:
i = task
if len(pos[i]) > 0:
n_tot += len(pos[i_1]) * len(pos[i_2])
input_queue.put(task)

if verbose:
pbar = tqdm.tqdm(
total=n_tot, bar_format='{l_bar}{bar}[{elapsed}<{remaining}]',
smoothing=0)

p_list = []
for i in range(num_threads):
p = Process(target=compute_tpcf, args=(
mode, pos, tpcf, period, tpcf_args, tpcf_kwargs, input_queue,
output_queue))
p.start()
p_list.append(p)

while n < n_tot:
try:
task, xi = output_queue.get(False)

elif mode == 'cross':
for i, s in enumerate(combinations):
if tpcf_matrix is None:
if mode == 'auto':
tpcf_matrix = np.zeros(
(len(xi.ravel()), len(pos), len(pos)))
else:
tpcf_matrix = np.zeros((len(xi.ravel()), len(pos)))

if len(pos[s]) == 0:
continue
if mode == 'auto':
i_1, i_2 = task
tpcf_matrix[:, i_1, i_2] += xi.ravel()
tpcf_matrix[:, i_2, i_1] = tpcf_matrix[:, i_1, i_2]
n += len(pos[i_1]) * len(pos[i_2])
if verbose:
pbar.update(len(pos[i_1]) * len(pos[i_2]))
else:
i = task
tpcf_matrix[:, i] += xi.ravel()
n += len(pos[i])
if verbose:
pbar.update(len(pos[i]))

if verbose:
n_done += n[s]
print_progress(n_done / np.sum(n))
tpcf_shape = xi.shape

xi = tpcf(pos[s], *tpcf_args, period=period, **tpcf_kwargs)
except Empty:
time.sleep(0.001)
pass

if i == 0:
tpcf_matrix = np.zeros((len(xi.ravel()), len(pos)))
tpcf_shape = xi.shape
for p in p_list:
p.join()

tpcf_matrix[:, s] = xi.ravel()
if verbose:
pbar.close()

return tpcf_matrix, tpcf_shape

Expand Down Expand Up @@ -415,23 +450,17 @@ def tabulate(cls, halocat, tpcf, *tpcf_args,

pos_bin.append(pos)

if mode == 'auto':
combinations = itertools.combinations_with_replacement(
range(len(halotab.gal_type)), 2)
else:
combinations = range(len(halotab.gal_type))

if xyz == 'xyz':
tpcf_matrix, tpcf_shape = compute_tpcf_matrix(
mode, pos_bin, tpcf, period, tpcf_args, tpcf_kwargs,
combinations, num_threads=num_threads, verbose=verbose)
num_threads=num_threads, verbose=verbose)

if not project_xyz or mode == 'cross':
break
elif xyz != 'xyz':
tpcf_matrix += compute_tpcf_matrix(
mode, pos_bin, tpcf, period, tpcf_args, tpcf_kwargs,
combinations, num_threads=num_threads, verbose=verbose)[0]
num_threads=num_threads, verbose=verbose)[0]

if project_xyz and mode == 'auto':
tpcf_matrix /= 3.0
Expand Down

0 comments on commit 62f2f5b

Please sign in to comment.