Skip to content

Commit

Permalink
Implement mc_diffmah_params_singlesat with t_peak option
Browse files Browse the repository at this point in the history
  • Loading branch information
aphearin committed Dec 4, 2024
1 parent 4d0cfdd commit 3770918
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 9 deletions.
95 changes: 88 additions & 7 deletions diffmah/diffmahpop_kernels/mc_bimod_sats.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,21 @@ def _mean_diffmah_params(diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0):
return mah_params_early, mah_params_late, frac_early_cens


@jjit
def _mean_diffmah_params_t_peak(diffmahpop_params, lgm_obs, t_obs, t_peak, ran_key):
mah_params_early = _mean_diffmah_params_early_t_peak(
diffmahpop_params, lgm_obs, t_obs, t_peak, ran_key
)

mah_params_late = _mean_diffmah_params_late_t_peak(
diffmahpop_params, lgm_obs, t_obs, t_peak, ran_key
)

frac_early_cens = _frac_early_cens_kern(diffmahpop_params, lgm_obs, t_obs)

return mah_params_early, mah_params_late, frac_early_cens


@jjit
def _mean_diffmah_params_early(diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0):
model_params = get_component_model_params(diffmahpop_params)
Expand Down Expand Up @@ -77,6 +92,35 @@ def _mean_diffmah_params_early(diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0)
return mah_params


@jjit
def _mean_diffmah_params_early_t_peak(
diffmahpop_params, lgm_obs, t_obs, t_peak, ran_key
):
model_params = get_component_model_params(diffmahpop_params)
(
tp_pdf_cens_params,
tp_pdf_sats_params,
logm0_params,
logm0_sats_params,
logtc_params,
early_index_params,
late_index_params,
fec_params,
cov_params,
) = model_params

tpc_key, ran_key = jran.split(ran_key, 2)

logm0 = _pred_logm0_kern_early(logm0_sats_params, lgm_obs, t_obs, t_peak)
logtc = _pred_logtc_early(logtc_params, lgm_obs, t_obs, t_peak)
early_index = _pred_early_index_early(early_index_params, lgm_obs, t_obs, t_peak)
late_index = _pred_late_index_early(late_index_params, lgm_obs)

mah_params = DiffmahParams(logm0, logtc, early_index, late_index, t_peak)

return mah_params


@jjit
def _mean_diffmah_params_late(diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0):
model_params = get_component_model_params(diffmahpop_params)
Expand Down Expand Up @@ -108,8 +152,45 @@ def _mean_diffmah_params_late(diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0):


@jjit
def mc_diffmah_params_singlecen(diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0):
_res = _mean_diffmah_params(diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0)
def _mean_diffmah_params_late_t_peak(
diffmahpop_params, lgm_obs, t_obs, t_peak, ran_key
):
model_params = get_component_model_params(diffmahpop_params)
(
tp_pdf_cens_params,
tp_pdf_sats_params,
logm0_params,
logm0_sats_params,
logtc_params,
early_index_params,
late_index_params,
fec_params,
cov_params,
) = model_params

tpc_key, ran_key = jran.split(ran_key, 2)

logm0 = _pred_logm0_kern_late(logm0_sats_params, lgm_obs, t_obs, t_peak)
logtc = _pred_logtc_late(logtc_params, lgm_obs, t_obs, t_peak)
early_index = _pred_early_index_late(early_index_params, lgm_obs, t_obs, t_peak)
late_index = _pred_late_index_late(late_index_params, lgm_obs)

mah_params = DiffmahParams(logm0, logtc, early_index, late_index, t_peak)

return mah_params


@jjit
def mc_diffmah_params_singlesat(
diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0, t_peak=None
):
if t_peak is None:
_res = _mean_diffmah_params(diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0)
else:
_res = _mean_diffmah_params_t_peak(
diffmahpop_params, lgm_obs, t_obs, t_peak, ran_key
)

(mean_mah_params_early, mean_mah_params_late, frac_early_cens) = _res

mean_mah_u_params_early = get_unbounded_mah_params(mean_mah_params_early)
Expand Down Expand Up @@ -139,8 +220,8 @@ def mc_diffmah_params_singlecen(diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0


@jjit
def _mc_diffmah_singlecen(diffmahpop_params, tarr, lgm_obs, t_obs, ran_key, lgt0):
_res = mc_diffmah_params_singlecen(diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0)
def _mc_diffmah_singlesat(diffmahpop_params, tarr, lgm_obs, t_obs, ran_key, lgt0):
_res = mc_diffmah_params_singlesat(diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0)
mah_params_early, mah_params_late, frac_early_cens = _res
dmhdt_early, log_mah_early = mah_singlehalo(mah_params_early, tarr, lgt0)
dmhdt_late, log_mah_late = mah_singlehalo(mah_params_late, tarr, lgt0)
Expand All @@ -152,7 +233,7 @@ def _mc_diffmah_singlecen(diffmahpop_params, tarr, lgm_obs, t_obs, ran_key, lgt0


_V = (None, None, 0, 0, 0, None)
_mc_diffmah_singlecen_vmap_kern = jjit(vmap(_mc_diffmah_singlecen, in_axes=_V))
_mc_diffmah_singlesat_vmap_kern = jjit(vmap(_mc_diffmah_singlesat, in_axes=_V))


@partial(jjit, static_argnames=["n_mc"])
Expand All @@ -161,7 +242,7 @@ def _mc_diffmah_halo_sample(
):
zz = jnp.zeros(n_mc)
ran_keys = jran.split(ran_key, n_mc)
return _mc_diffmah_singlecen_vmap_kern(
return _mc_diffmah_singlesat_vmap_kern(
diffmahpop_params, tarr, lgm_obs + zz, t_obs + zz, ran_keys, lgt0
)

Expand All @@ -173,7 +254,7 @@ def mc_satpop(diffmahpop_params, tarr, lgm_obs, t_obs, ran_key, lgt0):
ran_keys = jran.split(ran_key, n_mc + 1)
dmah_keys = ran_keys[:-1]
uran_key = ran_keys[-1]
_res = _mc_diffmah_singlecen_vmap_kern(
_res = _mc_diffmah_singlesat_vmap_kern(
diffmahpop_params, tarr, lgm_obs, t_obs, dmah_keys, lgt0
)
p_e, dmhdt_early, log_mah_early = _res[0:3]
Expand Down
26 changes: 24 additions & 2 deletions diffmah/diffmahpop_kernels/tests/test_mc_diffmahpop_bimod_sats.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,20 +56,42 @@ def test_mean_diffmah_params():
assert np.all(np.isfinite(_x))


def test_mc_diffmah_params_singlecen():
def test_mc_diffmah_params_singlesat():
ran_key = jran.key(0)
t_0 = 13.0
lgt0 = np.log10(t_0)
t_obs = 10.0
lgmarr = np.linspace(10, 15, 20)
for lgm_obs in lgmarr:
args = (DEFAULT_DIFFMAHPOP_PARAMS, lgm_obs, t_obs, ran_key, lgt0)
_res = mcdpk.mc_diffmah_params_singlecen(*args)
_res = mcdpk.mc_diffmah_params_singlesat(*args)
mah_params_e, mah_params_l, frac_early = _res
assert np.all(np.isfinite(mah_params_e.logtc))
assert np.all(np.isfinite(mah_params_l.logtc))


def test_mc_diffmah_params_singlesat_agrees_with_fixed_t_peak_version():
ran_key = jran.key(0)
t_0 = 13.0
lgt0 = np.log10(t_0)
t_obs = 10.0
lgmarr = np.linspace(10, 15, 20)
for lgm_obs in lgmarr:
args = (DEFAULT_DIFFMAHPOP_PARAMS, lgm_obs, t_obs, ran_key, lgt0)
_res = mcdpk.mc_diffmah_params_singlesat(*args)
mah_params_e, mah_params_l, frac_early = _res

_res2 = mcdpk.mc_diffmah_params_singlesat(*args, t_peak=mah_params_e.t_peak)
mah_params_e2, mah_params_l2, frac_early2 = _res2
for p, p2 in zip(mah_params_e, mah_params_e2):
assert np.allclose(p, p2)

_res3 = mcdpk.mc_diffmah_params_singlesat(*args, t_peak=mah_params_l.t_peak)
mah_params_e3, mah_params_l3, frac_early3 = _res3
for p, p2 in zip(mah_params_l, mah_params_l3):
assert np.allclose(p, p2)


def test_predict_mah_moments_singlebin():
ran_key = jran.key(0)
t_0 = 13.0
Expand Down

0 comments on commit 3770918

Please sign in to comment.