Skip to content

Commit

Permalink
Load MAHs on single rank and scatter to others
Browse files Browse the repository at this point in the history
  • Loading branch information
AlanPearl committed Nov 14, 2024
1 parent 5c2ebe4 commit 498ef0c
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 20 deletions.
27 changes: 8 additions & 19 deletions scripts/hacc_discovery_sims_diffmah_fitter_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from diffmah import fitting_helpers as cfh
from mpi4py import MPI

from load_w0wa_cores import NCHUNKS, NUM_SUBVOLS_DISCOVERY, load_mahs
from load_w0wa_cores import NCHUNKS, NUM_SUBVOLS_DISCOVERY, load_mahs_per_rank

TMP_OUTPAT = "tmp_mah_fits_rank_{0}.dat"

Expand Down Expand Up @@ -90,24 +90,11 @@
comm.Barrier()
ichunk_start = time()

tarr, mahs = load_mahs(fn_data, fn_cfg, chunknum, nchunks=nchunks)

if rank == 0:
print("Number of halos in chunk = {}".format(mahs.shape[0]))

# Ensure the target MAHs are cumulative peak masses
mahs = np.maximum.accumulate(mahs, axis=1)

# Get data for rank
if args.test:
nhalos_tot = nranks * 5
else:
nhalos_tot = mahs.shape[0]
_a = np.arange(0, nhalos_tot).astype("i8")
indx = np.array_split(_a, nranks)[rank]

mahs_for_rank = mahs[indx]
tarr, mahs_for_rank = load_mahs_per_rank(
fn_data, fn_cfg, chunknum, nchunks=nchunks, comm=MPI.COMM_WORLD
)
nhalos_for_rank = mahs_for_rank.shape[0]
nhalos_tot = comm.reduce(nhalos_for_rank, op=MPI.SUM)

chunknum_str = f"{chunknum:0{nchar_chunks}d}"
outbase_chunk = f"subvol_{subvol_str}_chunk_{chunknum_str}"
Expand All @@ -129,7 +116,9 @@

msg = "\n\nWallclock runtime to fit {0} halos with {1} ranks = {2:.1f} seconds\n\n"
if rank == 0:
print("\nFinished with subvolume {}".format(isubvol))
print("\nFinished with subvolume {} chunk {}".format(
isubvol, chunknum
))
runtime = ichunk_end - ichunk_start
print(msg.format(nhalos_tot, nranks, runtime))

Expand Down
30 changes: 29 additions & 1 deletion scripts/load_w0wa_cores.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,15 @@
from haccytrees import Simulation as HACCSim
from haccytrees import coretrees

from mpi4py import MPI
from diffopt.multigrad.util import scatter_nd

import desi_cosmo

TASSO_DRN_DESI = "/Users/aphearin/work/DATA/DESI_W0WA"
MASS_COLNAME = "infall_tree_node_mass"

NCHUNKS = 500
NCHUNKS = 20
NUM_SUBVOLS_DISCOVERY = 96


Expand Down Expand Up @@ -43,3 +46,28 @@ def load_mahs(fn_data, fn_cfg, chunknum, nchunks=NCHUNKS, mass_colname=MASS_COLN
tarr = flat_wcdm.age_at_z(zarr, *cosmo_params)

return tarr, mahs


def load_mahs_per_rank(fn_data, fn_cfg, chunknum, nchunks=NCHUNKS,
mass_colname=MASS_COLNAME, comm=None):
if comm is None:
comm = MPI.COMM_WORLD

if comm.rank == 0:
tarr, mahs = load_mahs(
fn_data, fn_cfg, chunknum, nchunks=nchunks,
mass_colname=mass_colname
)

# Ensure the target MAHs are cumulative peak masses
mahs = np.maximum.accumulate(mahs, axis=1)
else:
tarr = None
mahs = None
mahs_for_rank = scatter_nd(mahs, axis=0, comm=comm, root=0)
tarr = comm.bcast(tarr, root=0)

if comm.rank == 0:
print("Number of halos in chunk = {}".format(mahs.shape[0]))

return tarr, mahs_for_rank

0 comments on commit 498ef0c

Please sign in to comment.