diff --git a/diffmah/diffmahpop_kernels/mc_bimod_sats.py b/diffmah/diffmahpop_kernels/mc_bimod_sats.py index 2a7e1ff..2c1e529 100644 --- a/diffmah/diffmahpop_kernels/mc_bimod_sats.py +++ b/diffmah/diffmahpop_kernels/mc_bimod_sats.py @@ -47,6 +47,21 @@ def _mean_diffmah_params(diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0): return mah_params_early, mah_params_late, frac_early_cens +@jjit +def _mean_diffmah_params_t_peak(diffmahpop_params, lgm_obs, t_obs, t_peak, ran_key): + mah_params_early = _mean_diffmah_params_early_t_peak( + diffmahpop_params, lgm_obs, t_obs, t_peak, ran_key + ) + + mah_params_late = _mean_diffmah_params_late_t_peak( + diffmahpop_params, lgm_obs, t_obs, t_peak, ran_key + ) + + frac_early_cens = _frac_early_cens_kern(diffmahpop_params, lgm_obs, t_obs) + + return mah_params_early, mah_params_late, frac_early_cens + + @jjit def _mean_diffmah_params_early(diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0): model_params = get_component_model_params(diffmahpop_params) @@ -77,6 +92,35 @@ def _mean_diffmah_params_early(diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0) return mah_params +@jjit +def _mean_diffmah_params_early_t_peak( + diffmahpop_params, lgm_obs, t_obs, t_peak, ran_key +): + model_params = get_component_model_params(diffmahpop_params) + ( + tp_pdf_cens_params, + tp_pdf_sats_params, + logm0_params, + logm0_sats_params, + logtc_params, + early_index_params, + late_index_params, + fec_params, + cov_params, + ) = model_params + + tpc_key, ran_key = jran.split(ran_key, 2) + + logm0 = _pred_logm0_kern_early(logm0_sats_params, lgm_obs, t_obs, t_peak) + logtc = _pred_logtc_early(logtc_params, lgm_obs, t_obs, t_peak) + early_index = _pred_early_index_early(early_index_params, lgm_obs, t_obs, t_peak) + late_index = _pred_late_index_early(late_index_params, lgm_obs) + + mah_params = DiffmahParams(logm0, logtc, early_index, late_index, t_peak) + + return mah_params + + @jjit def _mean_diffmah_params_late(diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0): model_params = get_component_model_params(diffmahpop_params) @@ -108,8 +152,45 @@ def _mean_diffmah_params_late(diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0): @jjit -def mc_diffmah_params_singlecen(diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0): - _res = _mean_diffmah_params(diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0) +def _mean_diffmah_params_late_t_peak( + diffmahpop_params, lgm_obs, t_obs, t_peak, ran_key +): + model_params = get_component_model_params(diffmahpop_params) + ( + tp_pdf_cens_params, + tp_pdf_sats_params, + logm0_params, + logm0_sats_params, + logtc_params, + early_index_params, + late_index_params, + fec_params, + cov_params, + ) = model_params + + tpc_key, ran_key = jran.split(ran_key, 2) + + logm0 = _pred_logm0_kern_late(logm0_sats_params, lgm_obs, t_obs, t_peak) + logtc = _pred_logtc_late(logtc_params, lgm_obs, t_obs, t_peak) + early_index = _pred_early_index_late(early_index_params, lgm_obs, t_obs, t_peak) + late_index = _pred_late_index_late(late_index_params, lgm_obs) + + mah_params = DiffmahParams(logm0, logtc, early_index, late_index, t_peak) + + return mah_params + + +@jjit +def mc_diffmah_params_singlesat( + diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0, t_peak=None +): + if t_peak is None: + _res = _mean_diffmah_params(diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0) + else: + _res = _mean_diffmah_params_t_peak( + diffmahpop_params, lgm_obs, t_obs, t_peak, ran_key + ) + (mean_mah_params_early, mean_mah_params_late, frac_early_cens) = _res mean_mah_u_params_early = get_unbounded_mah_params(mean_mah_params_early) @@ -139,8 +220,8 @@ def mc_diffmah_params_singlecen(diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0 @jjit -def _mc_diffmah_singlecen(diffmahpop_params, tarr, lgm_obs, t_obs, ran_key, lgt0): - _res = mc_diffmah_params_singlecen(diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0) +def _mc_diffmah_singlesat(diffmahpop_params, tarr, lgm_obs, t_obs, ran_key, lgt0): + _res = mc_diffmah_params_singlesat(diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0) mah_params_early, mah_params_late, frac_early_cens = _res dmhdt_early, log_mah_early = mah_singlehalo(mah_params_early, tarr, lgt0) dmhdt_late, log_mah_late = mah_singlehalo(mah_params_late, tarr, lgt0) @@ -152,7 +233,7 @@ def _mc_diffmah_singlecen(diffmahpop_params, tarr, lgm_obs, t_obs, ran_key, lgt0 _V = (None, None, 0, 0, 0, None) -_mc_diffmah_singlecen_vmap_kern = jjit(vmap(_mc_diffmah_singlecen, in_axes=_V)) +_mc_diffmah_singlesat_vmap_kern = jjit(vmap(_mc_diffmah_singlesat, in_axes=_V)) @partial(jjit, static_argnames=["n_mc"]) @@ -161,7 +242,7 @@ def _mc_diffmah_halo_sample( ): zz = jnp.zeros(n_mc) ran_keys = jran.split(ran_key, n_mc) - return _mc_diffmah_singlecen_vmap_kern( + return _mc_diffmah_singlesat_vmap_kern( diffmahpop_params, tarr, lgm_obs + zz, t_obs + zz, ran_keys, lgt0 ) @@ -173,7 +254,7 @@ def mc_satpop(diffmahpop_params, tarr, lgm_obs, t_obs, ran_key, lgt0): ran_keys = jran.split(ran_key, n_mc + 1) dmah_keys = ran_keys[:-1] uran_key = ran_keys[-1] - _res = _mc_diffmah_singlecen_vmap_kern( + _res = _mc_diffmah_singlesat_vmap_kern( diffmahpop_params, tarr, lgm_obs, t_obs, dmah_keys, lgt0 ) p_e, dmhdt_early, log_mah_early = _res[0:3] diff --git a/diffmah/diffmahpop_kernels/tests/test_mc_diffmahpop_bimod_sats.py b/diffmah/diffmahpop_kernels/tests/test_mc_diffmahpop_bimod_sats.py index 25a55ef..9dc6387 100644 --- a/diffmah/diffmahpop_kernels/tests/test_mc_diffmahpop_bimod_sats.py +++ b/diffmah/diffmahpop_kernels/tests/test_mc_diffmahpop_bimod_sats.py @@ -56,7 +56,7 @@ def test_mean_diffmah_params(): assert np.all(np.isfinite(_x)) -def test_mc_diffmah_params_singlecen(): +def test_mc_diffmah_params_singlesat(): ran_key = jran.key(0) t_0 = 13.0 lgt0 = np.log10(t_0) @@ -64,12 +64,34 @@ def test_mc_diffmah_params_singlecen(): lgmarr = np.linspace(10, 15, 20) for lgm_obs in lgmarr: args = (DEFAULT_DIFFMAHPOP_PARAMS, lgm_obs, t_obs, ran_key, lgt0) - _res = mcdpk.mc_diffmah_params_singlecen(*args) + _res = mcdpk.mc_diffmah_params_singlesat(*args) mah_params_e, mah_params_l, frac_early = _res assert np.all(np.isfinite(mah_params_e.logtc)) assert np.all(np.isfinite(mah_params_l.logtc)) +def test_mc_diffmah_params_singlesat_agrees_with_fixed_t_peak_version(): + ran_key = jran.key(0) + t_0 = 13.0 + lgt0 = np.log10(t_0) + t_obs = 10.0 + lgmarr = np.linspace(10, 15, 20) + for lgm_obs in lgmarr: + args = (DEFAULT_DIFFMAHPOP_PARAMS, lgm_obs, t_obs, ran_key, lgt0) + _res = mcdpk.mc_diffmah_params_singlesat(*args) + mah_params_e, mah_params_l, frac_early = _res + + _res2 = mcdpk.mc_diffmah_params_singlesat(*args, t_peak=mah_params_e.t_peak) + mah_params_e2, mah_params_l2, frac_early2 = _res2 + for p, p2 in zip(mah_params_e, mah_params_e2): + assert np.allclose(p, p2) + + _res3 = mcdpk.mc_diffmah_params_singlesat(*args, t_peak=mah_params_l.t_peak) + mah_params_e3, mah_params_l3, frac_early3 = _res3 + for p, p2 in zip(mah_params_l, mah_params_l3): + assert np.allclose(p, p2) + + def test_predict_mah_moments_singlebin(): ran_key = jran.key(0) t_0 = 13.0