Skip to content

Commit

Permalink
Add load_hacc_mahs.py module
Browse files Browse the repository at this point in the history
  • Loading branch information
aphearin committed Nov 22, 2024
1 parent a75ed5a commit 39c462e
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 0 deletions.
2 changes: 2 additions & 0 deletions diffmah/data_loaders/README.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
This subpackage contains coding for loading simulation data in the form appropriate for diffmah.
Modules may require additional dependencies beyond the diffmah package requirements.
Empty file.
108 changes: 108 additions & 0 deletions diffmah/data_loaders/load_hacc_mahs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
"""Module loads mass assembly history data for diffmah
"""

import numpy as np

try:
from dsps.cosmology import flat_wcdm

HAS_DSPS = True
except ImportError:
HAS_DSPS = False

try:
from haccytrees import Simulation as HACCSim
from haccytrees import coretrees

HAS_HACCYTREES = True
except ImportError:
HAS_HACCYTREES = False

try:
from mpi4py import MPI

COMM = MPI.COMM_WORLD
except ImportError:
MPI = COMM = None


MASS_COLNAME = "infall_tree_node_mass"


def _scatter_nd(array, axis=0, comm=COMM, root=0):
"""Scatter n-dimensional array from root to all ranks
This function is taken from https://github.com/AlanPearl/diffopt
"""
ans: np.ndarray = np.array([])
if comm.rank == root:
splits = np.array_split(array, comm.size, axis=axis)
for i in range(comm.size):
if i == root:
ans = splits[i]
else:
comm.send(splits[i], dest=i)
else:
ans = comm.recv(source=root)
return ans


def _load_forest(fn_data, sim_name, chunknum, nchunks):
if not HAS_HACCYTREES:
raise ImportError("Must have haccytrees installed to use this data loader")

try:
sim = HACCSim.simulations[sim_name]
except KeyError:
sim = HACCSim.parse_config(sim_name)

zarr = sim.step2z(np.array(sim.cosmotools_steps))

forest_matrices = coretrees.corematrix_reader(
fn_data,
calculate_secondary_host_row=True,
nchunks=nchunks,
chunknum=chunknum,
simulation=sim,
)
return sim, forest_matrices, zarr


def load_mahs(fn_data, sim_name, chunknum, nchunks, mass_colname=MASS_COLNAME):

sim, forest_matrices, zarr = _load_forest(fn_data, sim_name, chunknum, nchunks)
mahs = forest_matrices[mass_colname]

# Ensure the target MAHs are cumulative peak masses
mahs = np.maximum.accumulate(mahs, axis=1)

if not HAS_DSPS:
raise ImportError("Must have dsps installed to use this data loader")

cosmo_dsps = flat_wcdm.CosmoParams(
*(sim.cosmo.Omega_m, sim.cosmo.w0, sim.cosmo.wa, sim.cosmo.h)
)

tarr = flat_wcdm.age_at_z(zarr, *cosmo_dsps)

return tarr, mahs


def load_mahs_per_rank(
fn_data, sim_name, chunknum, nchunks, mass_colname=MASS_COLNAME, comm=None
):
if comm is None:
comm = MPI.COMM_WORLD

if comm.rank == 0:
tarr, mahs = load_mahs(
fn_data, sim_name, chunknum, nchunks, mass_colname=mass_colname
)
else:
tarr = None
mahs = None
mahs_for_rank = _scatter_nd(mahs, axis=0, comm=comm, root=0)
tarr = comm.bcast(tarr, root=0)

return tarr, mahs_for_rank

0 comments on commit 39c462e

Please sign in to comment.