Skip to content

Commit

Permalink
Merge pull request #155 from ArgonneCPAC/mc_refactor
Browse files Browse the repository at this point in the history
Refactor DiffmahPop Monte Carlo generators
  • Loading branch information
aphearin authored Dec 4, 2024
2 parents 01edb30 + 20ec6ad commit fca9835
Show file tree
Hide file tree
Showing 4 changed files with 323 additions and 11 deletions.
110 changes: 108 additions & 2 deletions diffmah/diffmahpop_kernels/mc_bimod_cens.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,21 @@ def _mean_diffmah_params(diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0):
return mah_params_early, mah_params_late, frac_early_cens


@jjit
def _mean_diffmah_params_t_peak(diffmahpop_params, lgm_obs, t_obs, t_peak, ran_key):
mah_params_early = _mean_diffmah_params_early_t_peak(
diffmahpop_params, lgm_obs, t_obs, t_peak, ran_key
)

mah_params_late = _mean_diffmah_params_late_t_peak(
diffmahpop_params, lgm_obs, t_obs, t_peak, ran_key
)

frac_early_cens = _frac_early_cens_kern(diffmahpop_params, lgm_obs, t_obs)

return mah_params_early, mah_params_late, frac_early_cens


@jjit
def _mean_diffmah_params_early(diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0):
t_0 = 10**lgt0
Expand Down Expand Up @@ -77,6 +92,35 @@ def _mean_diffmah_params_early(diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0)
return mah_params


@jjit
def _mean_diffmah_params_early_t_peak(
diffmahpop_params, lgm_obs, t_obs, t_peak, ran_key
):
model_params = get_component_model_params(diffmahpop_params)
(
tp_pdf_cens_params,
tp_pdf_sats_params,
logm0_params,
logm0_sats_params,
logtc_params,
early_index_params,
late_index_params,
fec_params,
cov_params,
) = model_params

tpc_key, ran_key = jran.split(ran_key, 2)

logm0 = _pred_logm0_kern_early(logm0_params, lgm_obs, t_obs, t_peak)
logtc = _pred_logtc_early(logtc_params, lgm_obs, t_obs, t_peak)
early_index = _pred_early_index_early(early_index_params, lgm_obs, t_obs, t_peak)
late_index = _pred_late_index_early(late_index_params, lgm_obs)

mah_params = DiffmahParams(logm0, logtc, early_index, late_index, t_peak)

return mah_params


@jjit
def _mean_diffmah_params_late(diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0):
t_0 = 10**lgt0
Expand Down Expand Up @@ -108,8 +152,45 @@ def _mean_diffmah_params_late(diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0):


@jjit
def mc_diffmah_params_singlecen(diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0):
_res = _mean_diffmah_params(diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0)
def _mean_diffmah_params_late_t_peak(
diffmahpop_params, lgm_obs, t_obs, t_peak, ran_key
):
model_params = get_component_model_params(diffmahpop_params)
(
tp_pdf_cens_params,
tp_pdf_sats_params,
logm0_params,
logm0_sats_params,
logtc_params,
early_index_params,
late_index_params,
fec_params,
cov_params,
) = model_params

tpc_key, ran_key = jran.split(ran_key, 2)

logm0 = _pred_logm0_kern_late(logm0_params, lgm_obs, t_obs, t_peak)
logtc = _pred_logtc_late(logtc_params, lgm_obs, t_obs, t_peak)
early_index = _pred_early_index_late(early_index_params, lgm_obs, t_obs, t_peak)
late_index = _pred_late_index_late(late_index_params, lgm_obs)

mah_params = DiffmahParams(logm0, logtc, early_index, late_index, t_peak)

return mah_params


@jjit
def mc_diffmah_params_singlecen(
diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0, t_peak=None
):
if t_peak is None:
_res = _mean_diffmah_params(diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0)
else:
_res = _mean_diffmah_params_t_peak(
diffmahpop_params, lgm_obs, t_obs, t_peak, ran_key
)

(
mean_mah_params_early,
mean_mah_params_late,
Expand Down Expand Up @@ -159,6 +240,13 @@ def _mc_diffmah_singlecen(diffmahpop_params, tarr, lgm_obs, t_obs, ran_key, lgt0
_mc_diffmah_singlecen_vmap_kern = jjit(vmap(_mc_diffmah_singlecen, in_axes=_V))


_P1 = (None, 0, 0, 0, None, None)
mc_diffmah_params_cenpop_kern1 = jjit(vmap(mc_diffmah_params_singlecen, in_axes=_P1))

_P2 = (None, 0, 0, 0, None, 0)
mc_diffmah_params_cenpop_kern2 = jjit(vmap(mc_diffmah_params_singlecen, in_axes=_P2))


@partial(jjit, static_argnames=["n_mc"])
def _mc_diffmah_halo_sample(
diffmahpop_params, tarr, lgm_obs, t_obs, ran_key, lgt0, n_mc=NH_PER_M0BIN
Expand All @@ -170,6 +258,24 @@ def _mc_diffmah_halo_sample(
)


@jjit
def mc_diffmah_cenpop(diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0, t_peak=None):
""""""
early_late_key, ran_key = jran.split(ran_key, 2)
ran_keys = jran.split(ran_key, lgm_obs.size)
args = diffmahpop_params, lgm_obs, t_obs, ran_keys, lgt0, t_peak
if t_peak is None:
_res = mc_diffmah_params_cenpop_kern1(*args)
else:
_res = mc_diffmah_params_cenpop_kern2(*args)
mah_params_early, mah_params_late, frac_early_cens = _res
uran = jran.uniform(early_late_key, shape=frac_early_cens.shape)
mc_early = uran < frac_early_cens
_p = [jnp.where(mc_early, x, y) for x, y in zip(mah_params_early, mah_params_late)]
mah_params = DEFAULT_MAH_PARAMS._make(_p)
return mah_params, mah_params_early, mah_params_late, frac_early_cens, mc_early


@jjit
def mc_cenpop(diffmahpop_params, tarr, lgm_obs, t_obs, ran_key, lgt0):
""""""
Expand Down
120 changes: 113 additions & 7 deletions diffmah/diffmahpop_kernels/mc_bimod_sats.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,21 @@ def _mean_diffmah_params(diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0):
return mah_params_early, mah_params_late, frac_early_cens


@jjit
def _mean_diffmah_params_t_peak(diffmahpop_params, lgm_obs, t_obs, t_peak, ran_key):
mah_params_early = _mean_diffmah_params_early_t_peak(
diffmahpop_params, lgm_obs, t_obs, t_peak, ran_key
)

mah_params_late = _mean_diffmah_params_late_t_peak(
diffmahpop_params, lgm_obs, t_obs, t_peak, ran_key
)

frac_early_cens = _frac_early_cens_kern(diffmahpop_params, lgm_obs, t_obs)

return mah_params_early, mah_params_late, frac_early_cens


@jjit
def _mean_diffmah_params_early(diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0):
model_params = get_component_model_params(diffmahpop_params)
Expand Down Expand Up @@ -77,6 +92,35 @@ def _mean_diffmah_params_early(diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0)
return mah_params


@jjit
def _mean_diffmah_params_early_t_peak(
diffmahpop_params, lgm_obs, t_obs, t_peak, ran_key
):
model_params = get_component_model_params(diffmahpop_params)
(
tp_pdf_cens_params,
tp_pdf_sats_params,
logm0_params,
logm0_sats_params,
logtc_params,
early_index_params,
late_index_params,
fec_params,
cov_params,
) = model_params

tpc_key, ran_key = jran.split(ran_key, 2)

logm0 = _pred_logm0_kern_early(logm0_sats_params, lgm_obs, t_obs, t_peak)
logtc = _pred_logtc_early(logtc_params, lgm_obs, t_obs, t_peak)
early_index = _pred_early_index_early(early_index_params, lgm_obs, t_obs, t_peak)
late_index = _pred_late_index_early(late_index_params, lgm_obs)

mah_params = DiffmahParams(logm0, logtc, early_index, late_index, t_peak)

return mah_params


@jjit
def _mean_diffmah_params_late(diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0):
model_params = get_component_model_params(diffmahpop_params)
Expand Down Expand Up @@ -108,8 +152,45 @@ def _mean_diffmah_params_late(diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0):


@jjit
def mc_diffmah_params_singlecen(diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0):
_res = _mean_diffmah_params(diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0)
def _mean_diffmah_params_late_t_peak(
diffmahpop_params, lgm_obs, t_obs, t_peak, ran_key
):
model_params = get_component_model_params(diffmahpop_params)
(
tp_pdf_cens_params,
tp_pdf_sats_params,
logm0_params,
logm0_sats_params,
logtc_params,
early_index_params,
late_index_params,
fec_params,
cov_params,
) = model_params

tpc_key, ran_key = jran.split(ran_key, 2)

logm0 = _pred_logm0_kern_late(logm0_sats_params, lgm_obs, t_obs, t_peak)
logtc = _pred_logtc_late(logtc_params, lgm_obs, t_obs, t_peak)
early_index = _pred_early_index_late(early_index_params, lgm_obs, t_obs, t_peak)
late_index = _pred_late_index_late(late_index_params, lgm_obs)

mah_params = DiffmahParams(logm0, logtc, early_index, late_index, t_peak)

return mah_params


@jjit
def mc_diffmah_params_singlesat(
diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0, t_peak=None
):
if t_peak is None:
_res = _mean_diffmah_params(diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0)
else:
_res = _mean_diffmah_params_t_peak(
diffmahpop_params, lgm_obs, t_obs, t_peak, ran_key
)

(mean_mah_params_early, mean_mah_params_late, frac_early_cens) = _res

mean_mah_u_params_early = get_unbounded_mah_params(mean_mah_params_early)
Expand Down Expand Up @@ -139,8 +220,8 @@ def mc_diffmah_params_singlecen(diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0


@jjit
def _mc_diffmah_singlecen(diffmahpop_params, tarr, lgm_obs, t_obs, ran_key, lgt0):
_res = mc_diffmah_params_singlecen(diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0)
def _mc_diffmah_singlesat(diffmahpop_params, tarr, lgm_obs, t_obs, ran_key, lgt0):
_res = mc_diffmah_params_singlesat(diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0)
mah_params_early, mah_params_late, frac_early_cens = _res
dmhdt_early, log_mah_early = mah_singlehalo(mah_params_early, tarr, lgt0)
dmhdt_late, log_mah_late = mah_singlehalo(mah_params_late, tarr, lgt0)
Expand All @@ -152,7 +233,7 @@ def _mc_diffmah_singlecen(diffmahpop_params, tarr, lgm_obs, t_obs, ran_key, lgt0


_V = (None, None, 0, 0, 0, None)
_mc_diffmah_singlecen_vmap_kern = jjit(vmap(_mc_diffmah_singlecen, in_axes=_V))
_mc_diffmah_singlesat_vmap_kern = jjit(vmap(_mc_diffmah_singlesat, in_axes=_V))


@partial(jjit, static_argnames=["n_mc"])
Expand All @@ -161,19 +242,44 @@ def _mc_diffmah_halo_sample(
):
zz = jnp.zeros(n_mc)
ran_keys = jran.split(ran_key, n_mc)
return _mc_diffmah_singlecen_vmap_kern(
return _mc_diffmah_singlesat_vmap_kern(
diffmahpop_params, tarr, lgm_obs + zz, t_obs + zz, ran_keys, lgt0
)


_P1 = (None, 0, 0, 0, None, None)
mc_diffmah_params_satpop_kern1 = jjit(vmap(mc_diffmah_params_singlesat, in_axes=_P1))

_P2 = (None, 0, 0, 0, None, 0)
mc_diffmah_params_satpop_kern2 = jjit(vmap(mc_diffmah_params_singlesat, in_axes=_P2))


@jjit
def mc_diffmah_satpop(diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0, t_peak=None):
""""""
early_late_key, ran_key = jran.split(ran_key, 2)
ran_keys = jran.split(ran_key, lgm_obs.size)
args = diffmahpop_params, lgm_obs, t_obs, ran_keys, lgt0, t_peak
if t_peak is None:
_res = mc_diffmah_params_satpop_kern1(*args)
else:
_res = mc_diffmah_params_satpop_kern2(*args)
mah_params_early, mah_params_late, frac_early_cens = _res
uran = jran.uniform(early_late_key, shape=frac_early_cens.shape)
mc_early = uran < frac_early_cens
_p = [jnp.where(mc_early, x, y) for x, y in zip(mah_params_early, mah_params_late)]
mah_params = DEFAULT_MAH_PARAMS._make(_p)
return mah_params, mah_params_early, mah_params_late, frac_early_cens, mc_early


@jjit
def mc_satpop(diffmahpop_params, tarr, lgm_obs, t_obs, ran_key, lgt0):
""""""
n_mc = lgm_obs.shape[0]
ran_keys = jran.split(ran_key, n_mc + 1)
dmah_keys = ran_keys[:-1]
uran_key = ran_keys[-1]
_res = _mc_diffmah_singlecen_vmap_kern(
_res = _mc_diffmah_singlesat_vmap_kern(
diffmahpop_params, tarr, lgm_obs, t_obs, dmah_keys, lgt0
)
p_e, dmhdt_early, log_mah_early = _res[0:3]
Expand Down
53 changes: 53 additions & 0 deletions diffmah/diffmahpop_kernels/tests/test_mc_diffmahpop_bimod_cens.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,59 @@ def test_mc_diffmah_params_singlecen():
assert np.all(np.isfinite(mah_params_l.logtc))


def test_mc_diffmah_params_singlecen_agrees_with_fixed_t_peak_version():
ran_key = jran.key(0)
t_0 = 13.0
lgt0 = np.log10(t_0)
t_obs = 10.0
lgmarr = np.linspace(10, 15, 20)
for lgm_obs in lgmarr:
args = (DEFAULT_DIFFMAHPOP_PARAMS, lgm_obs, t_obs, ran_key, lgt0)
_res = mcdpk.mc_diffmah_params_singlecen(*args)
mah_params_e, mah_params_l, frac_early = _res

_res2 = mcdpk.mc_diffmah_params_singlecen(
*args,
t_peak=mah_params_e.t_peak,
)
mah_params_e2, mah_params_l2, frac_early2 = _res2
for p, p2 in zip(mah_params_e, mah_params_e2):
assert np.allclose(p, p2)

_res3 = mcdpk.mc_diffmah_params_singlecen(
*args,
t_peak=mah_params_l.t_peak,
)
mah_params_e3, mah_params_l3, frac_early3 = _res3
for p, p2 in zip(mah_params_l, mah_params_l3):
assert np.allclose(p, p2)


def test_mc_diffmah_cenpop():
ran_key = jran.key(0)
t_0 = 13.0
lgt0 = np.log10(t_0)

n_halos = 450
lgm_key, t_obs_key, t_peak_key, ran_key = jran.split(ran_key, 4)
lgm_obs = jran.uniform(lgm_key, minval=10, maxval=15, shape=(n_halos,))
t_obs = jran.uniform(t_obs_key, minval=2, maxval=15, shape=(n_halos,))
t_peak = jran.uniform(t_obs_key, minval=2, maxval=15, shape=(n_halos,))

args = DEFAULT_DIFFMAHPOP_PARAMS, lgm_obs, t_obs, ran_key, lgt0
_res = mcdpk.mc_diffmah_cenpop(*args)
mah_params, mah_params_early, mah_params_late, frac_early_cens, mc_early = _res
for x in mah_params:
assert x.shape == (n_halos,)
assert np.all(np.isfinite(x))

_res = mcdpk.mc_diffmah_cenpop(*args, t_peak=t_peak)
mah_params, mah_params_early, mah_params_late, frac_early_cens, mc_early = _res
for x in mah_params:
assert x.shape == (n_halos,)
assert np.all(np.isfinite(x))


def test_predict_mah_moments_singlebin():
ran_key = jran.key(0)
t_0 = 13.0
Expand Down
Loading

0 comments on commit fca9835

Please sign in to comment.