From 2f3e4941ec459e5946bac5e5b4259a38b04cca72 Mon Sep 17 00:00:00 2001 From: Andrew Hearin Date: Mon, 21 Oct 2024 13:35:23 -0500 Subject: [PATCH] Remove obsolete formulations of DiffmahPop --- diffmah/diffmahpop_kernels/cens_fithelp.py | 210 --------------- .../diffmahpop_kernels/censat_var_fithelp.py | 35 --- .../diffmahpop_kernels/diffmahpop_params.py | 224 ---------------- .../diffmahpop_params_censat.py | 243 ----------------- .../diffmahpop_params_monocensat.py | 230 ---------------- diffmah/diffmahpop_kernels/late_index_pop.py | 89 ------- .../logm0_kernels/__init__.py | 0 .../logm0_kernels/logm0_c0_kernels.py | 127 --------- .../logm0_kernels/logm0_c1_kernels.py | 138 ---------- .../logm0_kernels/logm0_pop.py | 73 ----- .../logm0_kernels/tests/__init__.py | 0 .../tests/test_logm0_c0_kernels.py | 65 ----- .../tests/test_logm0_c1_kernels.py | 65 ----- .../logm0_kernels/tests/test_logm0_pop.py | 86 ------ diffmah/diffmahpop_kernels/logtc_pop.py | 117 -------- .../mc_diffmahpop_kernels.py | 195 -------------- .../mc_diffmahpop_kernels_cens.py | 200 -------------- .../mc_diffmahpop_kernels_censat.py | 249 ------------------ .../mc_diffmahpop_kernels_monocens.py | 130 --------- ...diffmahpop_kernels_monocens_fixed_tpeak.py | 123 --------- .../mc_diffmahpop_kernels_monosats.py | 132 ---------- .../mc_diffmahpop_kernels_sats.py | 135 ---------- .../mean_param_fitting_kernels.py | 125 --------- .../diffmahpop_kernels/monocens_fithelp.py | 98 ------- .../monocens_fixed_tpeak_fithelp.py | 140 ---------- .../diffmahpop_kernels/monocensat_fithelp.py | 78 ------ .../diffmahpop_kernels/monosats_fithelp.py | 98 ------- .../t_peak_kernels/tests/test_tp_pdf_cens.py | 96 ------- .../tests/test_tp_pdf_monocens.py | 96 ------- .../t_peak_kernels/tp_pdf_monocens.py | 169 ------------ .../tests/test_cens_fithelp.py | 41 --- .../tests/test_diffmahpop_params.py | 23 -- .../tests/test_diffmahpop_params_censat.py | 23 -- .../test_diffmahpop_params_monocensat.py | 23 -- .../tests/test_early_index_pop.py | 70 ----- .../tests/test_late_index_pop.py | 68 ----- .../tests/test_logtc_pop.py | 62 ----- .../tests/test_mc_diffmahpop.py | 112 -------- .../tests/test_mc_diffmahpop_cens.py | 113 -------- .../tests/test_mc_diffmahpop_censat.py | 137 ---------- .../tests/test_mc_diffmahpop_monocens.py | 91 ------- ...test_mc_diffmahpop_monocens_fixed_tpeak.py | 103 -------- .../tests/test_mc_diffmahpop_monosats.py | 98 ------- .../tests/test_mc_diffmahpop_sats.py | 90 ------- .../tests/test_monocens_fithelp.py | 40 --- .../test_monocens_fixed_tpeak_fithelp.py | 44 ---- .../tests/test_monocensat_fithelp.py | 49 ---- .../tests/test_monosats_fithelp.py | 40 --- .../tests/test_variance_fithelp.py | 38 --- .../diffmahpop_kernels/variance_fithelp.py | 92 ------- 50 files changed, 5123 deletions(-) delete mode 100644 diffmah/diffmahpop_kernels/cens_fithelp.py delete mode 100644 diffmah/diffmahpop_kernels/censat_var_fithelp.py delete mode 100644 diffmah/diffmahpop_kernels/diffmahpop_params.py delete mode 100644 diffmah/diffmahpop_kernels/diffmahpop_params_censat.py delete mode 100644 diffmah/diffmahpop_kernels/diffmahpop_params_monocensat.py delete mode 100644 diffmah/diffmahpop_kernels/late_index_pop.py delete mode 100644 diffmah/diffmahpop_kernels/logm0_kernels/__init__.py delete mode 100644 diffmah/diffmahpop_kernels/logm0_kernels/logm0_c0_kernels.py delete mode 100644 diffmah/diffmahpop_kernels/logm0_kernels/logm0_c1_kernels.py delete mode 100644 diffmah/diffmahpop_kernels/logm0_kernels/logm0_pop.py delete mode 100644 diffmah/diffmahpop_kernels/logm0_kernels/tests/__init__.py delete mode 100644 diffmah/diffmahpop_kernels/logm0_kernels/tests/test_logm0_c0_kernels.py delete mode 100644 diffmah/diffmahpop_kernels/logm0_kernels/tests/test_logm0_c1_kernels.py delete mode 100644 diffmah/diffmahpop_kernels/logm0_kernels/tests/test_logm0_pop.py delete mode 100644 diffmah/diffmahpop_kernels/logtc_pop.py delete mode 100644 diffmah/diffmahpop_kernels/mc_diffmahpop_kernels.py delete mode 100644 diffmah/diffmahpop_kernels/mc_diffmahpop_kernels_cens.py delete mode 100644 diffmah/diffmahpop_kernels/mc_diffmahpop_kernels_censat.py delete mode 100644 diffmah/diffmahpop_kernels/mc_diffmahpop_kernels_monocens.py delete mode 100644 diffmah/diffmahpop_kernels/mc_diffmahpop_kernels_monocens_fixed_tpeak.py delete mode 100644 diffmah/diffmahpop_kernels/mc_diffmahpop_kernels_monosats.py delete mode 100644 diffmah/diffmahpop_kernels/mc_diffmahpop_kernels_sats.py delete mode 100644 diffmah/diffmahpop_kernels/mean_param_fitting_kernels.py delete mode 100644 diffmah/diffmahpop_kernels/monocens_fithelp.py delete mode 100644 diffmah/diffmahpop_kernels/monocens_fixed_tpeak_fithelp.py delete mode 100644 diffmah/diffmahpop_kernels/monocensat_fithelp.py delete mode 100644 diffmah/diffmahpop_kernels/monosats_fithelp.py delete mode 100644 diffmah/diffmahpop_kernels/t_peak_kernels/tests/test_tp_pdf_cens.py delete mode 100644 diffmah/diffmahpop_kernels/t_peak_kernels/tests/test_tp_pdf_monocens.py delete mode 100644 diffmah/diffmahpop_kernels/t_peak_kernels/tp_pdf_monocens.py delete mode 100644 diffmah/diffmahpop_kernels/tests/test_cens_fithelp.py delete mode 100644 diffmah/diffmahpop_kernels/tests/test_diffmahpop_params.py delete mode 100644 diffmah/diffmahpop_kernels/tests/test_diffmahpop_params_censat.py delete mode 100644 diffmah/diffmahpop_kernels/tests/test_diffmahpop_params_monocensat.py delete mode 100644 diffmah/diffmahpop_kernels/tests/test_early_index_pop.py delete mode 100644 diffmah/diffmahpop_kernels/tests/test_late_index_pop.py delete mode 100644 diffmah/diffmahpop_kernels/tests/test_logtc_pop.py delete mode 100644 diffmah/diffmahpop_kernels/tests/test_mc_diffmahpop.py delete mode 100644 diffmah/diffmahpop_kernels/tests/test_mc_diffmahpop_cens.py delete mode 100644 diffmah/diffmahpop_kernels/tests/test_mc_diffmahpop_censat.py delete mode 100644 diffmah/diffmahpop_kernels/tests/test_mc_diffmahpop_monocens.py delete mode 100644 diffmah/diffmahpop_kernels/tests/test_mc_diffmahpop_monocens_fixed_tpeak.py delete mode 100644 diffmah/diffmahpop_kernels/tests/test_mc_diffmahpop_monosats.py delete mode 100644 diffmah/diffmahpop_kernels/tests/test_mc_diffmahpop_sats.py delete mode 100644 diffmah/diffmahpop_kernels/tests/test_monocens_fithelp.py delete mode 100644 diffmah/diffmahpop_kernels/tests/test_monocens_fixed_tpeak_fithelp.py delete mode 100644 diffmah/diffmahpop_kernels/tests/test_monocensat_fithelp.py delete mode 100644 diffmah/diffmahpop_kernels/tests/test_monosats_fithelp.py delete mode 100644 diffmah/diffmahpop_kernels/tests/test_variance_fithelp.py delete mode 100644 diffmah/diffmahpop_kernels/variance_fithelp.py diff --git a/diffmah/diffmahpop_kernels/cens_fithelp.py b/diffmah/diffmahpop_kernels/cens_fithelp.py deleted file mode 100644 index da7edb4..0000000 --- a/diffmah/diffmahpop_kernels/cens_fithelp.py +++ /dev/null @@ -1,210 +0,0 @@ -""" -""" - -import numpy as np -from jax import jit as jjit -from jax import numpy as jnp -from jax import random as jran -from jax import value_and_grad, vmap - -from ..diffmah_kernels import DEFAULT_MAH_PARAMS, mah_halopop -from . import diffmahpop_params as dpp -from . import mc_diffmahpop_kernels as mcdk - -T_OBS_FIT_MIN = 1.5 -T_TARGET_VAR_MIN = 2.0 -T_TARGET_VAR_MAX = 0.1 -LGSMAH_MIN = -15.0 -EPS = 1e-3 - - -@jjit -def _get_var_weights(tarr, t_obs): - msk_std = tarr < T_TARGET_VAR_MIN - msk_std |= tarr > t_obs - T_TARGET_VAR_MAX - var_weights = jnp.where(msk_std, 0.0, 1.0) - return var_weights - - -@jjit -def _mse(x, y): - d = y - x - return jnp.mean(d * d) - - -@jjit -def _wmse(x, y, w): - d = y - x - return jnp.average(d * d, weights=w) - - -@jjit -def predict_mah_targets_singlebin( - diffmahpop_params, tarr, lgm_obs, t_obs, ran_key, lgt0 -): - _res = mcdk._mc_diffmah_halo_sample( - diffmahpop_params, tarr, lgm_obs, t_obs, ran_key, lgt0 - ) - ( - mah_params_tpt0, - mah_params_tp, - t_peak, - ftpt0, - mc_tpt0, - dmhdt_tpt0, - log_mah_tpt0, - dmhdt_tp, - log_mah_tp, - ) = _res - - f = ftpt0.reshape((-1, 1)) - delta_log_mah_tpt0 = log_mah_tpt0 - lgm_obs - delta_log_mah_tp = log_mah_tp - lgm_obs - mean_delta_log_mah = jnp.mean( - f * delta_log_mah_tpt0 + (1 - f) * delta_log_mah_tp, axis=0 - ) - - std_log_mah = jnp.std(f * log_mah_tpt0 + (1 - f) * log_mah_tp, axis=0) - - dmhdt_tpt0 = jnp.clip(dmhdt_tpt0, 10**LGSMAH_MIN) # make log-safe - dmhdt_tp = jnp.clip(dmhdt_tp, 10**LGSMAH_MIN) # make log-safe - - lgsmah_tpt0 = jnp.log10(dmhdt_tpt0) - log_mah_tpt0 # compute lgsmah - lgsmah_tpt0 = jnp.clip(lgsmah_tpt0, LGSMAH_MIN) # impose lgsMAH clip - - lgsmah_tp = jnp.log10(dmhdt_tp) - log_mah_tpt0 # compute lgsmah - lgsmah_tp = jnp.clip(lgsmah_tp, LGSMAH_MIN) # impose lgsMAH clip - - weights_ftpt0 = jnp.concatenate((ftpt0, 1 - ftpt0)) - lgsmah_pred = jnp.concatenate((lgsmah_tpt0, lgsmah_tp)) - - frac_peaked = jnp.average(lgsmah_pred == LGSMAH_MIN, axis=0, weights=weights_ftpt0) - - return mean_delta_log_mah, std_log_mah, frac_peaked - - -@jjit -def _loss_mah_moments_singlebin( - diffmahpop_params, - tarr, - lgm_obs, - t_obs, - ran_key, - lgt0, - target_mean_delta_log_mah, - target_std_log_mah, - target_frac_peaked, -): - _preds = predict_mah_targets_singlebin( - diffmahpop_params, tarr, lgm_obs, t_obs, ran_key, lgt0 - ) - mean_delta_log_mah, std_log_mah, frac_peaked = _preds - loss = _mse(mean_delta_log_mah, target_mean_delta_log_mah) - - std_weights = _get_var_weights(tarr, t_obs) - loss = loss + _wmse(std_log_mah, target_std_log_mah, std_weights) - # loss = loss + _mse(frac_peaked, target_frac_peaked) - return loss - - -_A = (None, 0, 0, 0, 0, None, 0, 0, 0) -_loss_mah_moments_multibin = jjit(vmap(_loss_mah_moments_singlebin, in_axes=_A)) - - -@jjit -def loss_mah_moments_multibin( - diffmahpop_params, - tarr_matrix, - lgm_obs_arr, - t_obs_arr, - ran_keys, - lgt0, - target_mean_delta_log_mah_matrix, - target_std_log_mah_matrix, - target_frac_peaked_matrix, -): - losses = _loss_mah_moments_multibin( - diffmahpop_params, - tarr_matrix, - lgm_obs_arr, - t_obs_arr, - ran_keys, - lgt0, - target_mean_delta_log_mah_matrix, - target_std_log_mah_matrix, - target_frac_peaked_matrix, - ) - return jnp.mean(losses) - - -@jjit -def global_loss_kern(diffmahpop_u_params, loss_data): - diffmahpop_params = dpp.get_diffmahpop_params_from_u_params(diffmahpop_u_params) - return loss_mah_moments_multibin(diffmahpop_params, *loss_data) - - -global_loss_and_grad_kern = jjit(value_and_grad(global_loss_kern)) - - -def compute_targets_singlebin(halo_samples, t_obs_samples, lgm_obs, t_obs, lgt0): - it_obs = np.argmin(np.abs(t_obs_samples - t_obs)) - t_obs = t_obs_samples[it_obs] - - cens = halo_samples[it_obs] - lgm_obs_arr = np.sort(np.unique(cens["lgm_obs"])) - mah_keys = ("logm0", "logtc", "early_index", "late_index") - mah_params = DEFAULT_MAH_PARAMS._make([cens[key] for key in mah_keys]) - ilgm_obs = np.argmin(np.abs(lgm_obs_arr - lgm_obs)) - lgm_obs = lgm_obs_arr[ilgm_obs] - mmsk = cens["lgm_obs"] == lgm_obs - mah_params_target = DEFAULT_MAH_PARAMS._make([x[mmsk] for x in mah_params]) - t_peak_target = cens["t_peak"][mmsk] - tarr = np.linspace(T_OBS_FIT_MIN, t_obs - EPS, 50) - dmhdt, log_mah = mah_halopop(mah_params_target, tarr, t_peak_target, lgt0) - - lgm_obs_sample = mah_halopop( - mah_params_target, np.zeros(1) + t_obs, t_peak_target, lgt0 - )[1][:, 0] - log_mah_rescaled = log_mah - (lgm_obs_sample.reshape((-1, 1)) - lgm_obs) - - delta_log_mah = log_mah_rescaled - lgm_obs - mean_delta_log_mah = np.mean(delta_log_mah, axis=0) - std_log_mah = np.std(log_mah_rescaled, axis=0) - - dmhdt = jnp.clip(dmhdt, 10**LGSMAH_MIN) # make log-safe - lgsmah = jnp.log10(dmhdt) - log_mah - lgsmah = jnp.clip(lgsmah, LGSMAH_MIN) - - frac_peaked = np.mean(lgsmah == LGSMAH_MIN, axis=0) - - return lgm_obs, t_obs, tarr, mean_delta_log_mah, std_log_mah, frac_peaked - - -def get_random_target_collection( - halo_samples, t_obs_samples, lgt0, ran_key, n_targets=100 -): - ran_key, lgm_obs_key, t_obs_key = jran.split(ran_key, 3) - - target_collector = [] - for __ in range(n_targets): - lgm_obs = jran.uniform(lgm_obs_key, minval=11, maxval=15, shape=()) - t_obs = jran.uniform(t_obs_key, minval=4, maxval=13.5, shape=()) - targets = compute_targets_singlebin( - halo_samples, t_obs_samples, lgm_obs, t_obs, lgt0 - ) - target_collector.append(targets) - lgm_obs_arr = np.array([x[0] for x in target_collector]) - t_obs_arr = np.array([x[1] for x in target_collector]) - tarr_matrix = np.array([x[2] for x in target_collector]) - mean_delta_log_mah_matrix = np.array([x[3] for x in target_collector]) - std_log_mahs_matrix = np.array([x[4] for x in target_collector]) - frac_peaked_matrix = np.array([x[5] for x in target_collector]) - - return ( - lgm_obs_arr, - t_obs_arr, - tarr_matrix, - mean_delta_log_mah_matrix, - std_log_mahs_matrix, - frac_peaked_matrix, - ) diff --git a/diffmah/diffmahpop_kernels/censat_var_fithelp.py b/diffmah/diffmahpop_kernels/censat_var_fithelp.py deleted file mode 100644 index 64e0193..0000000 --- a/diffmah/diffmahpop_kernels/censat_var_fithelp.py +++ /dev/null @@ -1,35 +0,0 @@ -""" -""" - -from jax import jit as jjit -from jax import numpy as jnp - -from . import mc_diffmahpop_kernels as mcdk - -T_OBS_FIT_MIN = 0.5 - - -@jjit -def _mse(x, y): - d = y - x - return jnp.mean(d * d) - - -@jjit -def _loss_mah_moments_singlebin_cens( - diffmahpop_params, - tarr, - lgm_obs, - t_obs, - ran_key, - lgt0, - target_mean_log_mah, - target_std_log_mah, -): - _preds = mcdk.predict_mah_moments_singlebin( - diffmahpop_params, tarr, lgm_obs, t_obs, ran_key, lgt0 - ) - mean_log_mah, std_log_mah = _preds - loss = _mse(mean_log_mah, target_mean_log_mah) - loss = loss + _mse(std_log_mah, target_std_log_mah) - return loss diff --git a/diffmah/diffmahpop_kernels/diffmahpop_params.py b/diffmah/diffmahpop_kernels/diffmahpop_params.py deleted file mode 100644 index be6e91c..0000000 --- a/diffmah/diffmahpop_kernels/diffmahpop_params.py +++ /dev/null @@ -1,224 +0,0 @@ -""" -""" - -from collections import OrderedDict, namedtuple - -from jax import jit as jjit - -from . import covariance_kernels, early_index_pop, ftpt0_cens, late_index_pop, logtc_pop -from .logm0_kernels import logm0_pop -from .t_peak_kernels import tp_pdf_cens - -DEFAULT_DIFFMAHPOP_PDICT = OrderedDict() -COMPONENT_PDICTS = ( - ftpt0_cens.DEFAULT_FTPT0_PDICT, - tp_pdf_cens.DEFAULT_TPCENS_PDICT, - logm0_pop.DEFAULT_LOGM0_PDICT, - logtc_pop.LOGTC_PDICT, - early_index_pop.EARLY_INDEX_PDICT, - late_index_pop.LATE_INDEX_PDICT, - covariance_kernels.DEFAULT_COV_PDICT, -) -for pdict in COMPONENT_PDICTS: - DEFAULT_DIFFMAHPOP_PDICT.update(pdict) -DiffmahPop_Params = namedtuple("DiffmahPop_Params", DEFAULT_DIFFMAHPOP_PDICT.keys()) -DEFAULT_DIFFMAHPOP_PARAMS = DiffmahPop_Params(**DEFAULT_DIFFMAHPOP_PDICT) - - -COMPONENT_U_PDICTS = ( - ftpt0_cens.DEFAULT_FTPT0_U_PARAMS._asdict(), - tp_pdf_cens.DEFAULT_TPCENS_U_PARAMS._asdict(), - logm0_pop.DEFAULT_LOGM0POP_U_PARAMS._asdict(), - logtc_pop.DEFAULT_LOGTC_U_PARAMS._asdict(), - early_index_pop.DEFAULT_EARLY_INDEX_U_PARAMS._asdict(), - late_index_pop.DEFAULT_LATE_INDEX_U_PARAMS._asdict(), - covariance_kernels.DEFAULT_COV_U_PARAMS._asdict(), -) -DEFAULT_DIFFMAHPOP_U_PDICT = OrderedDict() -for updict in COMPONENT_U_PDICTS: - DEFAULT_DIFFMAHPOP_U_PDICT.update(updict) -DiffmahPop_UParams = namedtuple("DiffmahPop_UParams", DEFAULT_DIFFMAHPOP_U_PDICT.keys()) -DEFAULT_DIFFMAHPOP_U_PARAMS = DiffmahPop_UParams(**DEFAULT_DIFFMAHPOP_U_PDICT) - - -@jjit -def get_component_model_params(diffmahpop_params): - ftpt0_cens_params = ftpt0_cens.FTPT0_Params( - *[getattr(diffmahpop_params, key) for key in ftpt0_cens.FTPT0_Params._fields] - ) - tp_pdf_cens_params = tp_pdf_cens.TPCens_Params( - *[getattr(diffmahpop_params, key) for key in tp_pdf_cens.TPCens_Params._fields] - ) - logm0_params = logm0_pop.LGM0Pop_Params( - *[getattr(diffmahpop_params, key) for key in logm0_pop.LGM0Pop_Params._fields] - ) - logtc_params = logtc_pop.Logtc_Params( - *[getattr(diffmahpop_params, key) for key in logtc_pop.Logtc_Params._fields] - ) - early_index_params = early_index_pop.EarlyIndex_Params( - *[ - getattr(diffmahpop_params, key) - for key in early_index_pop.EarlyIndex_Params._fields - ] - ) - late_index_params = late_index_pop.LateIndex_Params( - *[ - getattr(diffmahpop_params, key) - for key in late_index_pop.LateIndex_Params._fields - ] - ) - cov_params = covariance_kernels.CovParams( - *[ - getattr(diffmahpop_params, key) - for key in covariance_kernels.CovParams._fields - ] - ) - return ( - ftpt0_cens_params, - tp_pdf_cens_params, - logm0_params, - logtc_params, - early_index_params, - late_index_params, - cov_params, - ) - - -@jjit -def get_component_model_u_params(diffmahpop_u_params): - ftpt0_cens_u_params = ftpt0_cens.FTPT0_UParams( - *[getattr(diffmahpop_u_params, key) for key in ftpt0_cens.FTPT0_UParams._fields] - ) - tp_pdf_cens_u_params = tp_pdf_cens.TPCens_UParams( - *[ - getattr(diffmahpop_u_params, key) - for key in tp_pdf_cens.TPCens_UParams._fields - ] - ) - logm0_u_params = logm0_pop.LGM0Pop_UParams( - *[ - getattr(diffmahpop_u_params, key) - for key in logm0_pop.LGM0Pop_UParams._fields - ] - ) - logtc_u_params = logtc_pop.Logtc_UParams( - *[getattr(diffmahpop_u_params, key) for key in logtc_pop.Logtc_UParams._fields] - ) - early_index_u_params = early_index_pop.EarlyIndex_UParams( - *[ - getattr(diffmahpop_u_params, key) - for key in early_index_pop.EarlyIndex_UParams._fields - ] - ) - late_index_u_params = late_index_pop.LateIndex_UParams( - *[ - getattr(diffmahpop_u_params, key) - for key in late_index_pop.LateIndex_UParams._fields - ] - ) - cov_u_params = covariance_kernels.CovUParams( - *[ - getattr(diffmahpop_u_params, key) - for key in covariance_kernels.CovUParams._fields - ] - ) - - return ( - ftpt0_cens_u_params, - tp_pdf_cens_u_params, - logm0_u_params, - logtc_u_params, - early_index_u_params, - late_index_u_params, - cov_u_params, - ) - - -@jjit -def get_diffmahpop_params_from_u_params(diffmahpop_u_params): - component_model_u_params = get_component_model_u_params(diffmahpop_u_params) - ftpt0_u_params, tpc_u_params, logm0_u_params = component_model_u_params[:3] - logtc_u_params = component_model_u_params[3] - early_index_u_params, late_index_u_params = component_model_u_params[4:6] - cov_u_params = component_model_u_params[6] - - ftpt0_cens_params = ftpt0_cens.get_bounded_ftpt0_params(ftpt0_u_params) - tpc_params = tp_pdf_cens.get_bounded_tp_cens_params(tpc_u_params) - logm0_params = logm0_pop.get_bounded_m0pop_params(logm0_u_params) - logtc_params = logtc_pop.get_bounded_logtc_params(logtc_u_params) - early_index_params = early_index_pop.get_bounded_early_index_params( - early_index_u_params - ) - late_index_params = late_index_pop.get_bounded_late_index_params( - late_index_u_params - ) - cov_params = covariance_kernels.get_bounded_cov_params(cov_u_params) - - component_model_params = ( - ftpt0_cens_params, - tpc_params, - logm0_params, - logtc_params, - early_index_params, - late_index_params, - cov_params, - ) - diffmahpop_params = DEFAULT_DIFFMAHPOP_PARAMS._make(DEFAULT_DIFFMAHPOP_PARAMS) - for params in component_model_params: - diffmahpop_params = diffmahpop_params._replace(**params._asdict()) - - return diffmahpop_params - - -@jjit -def get_diffmahpop_u_params_from_params(diffmahpop_params): - component_model_params = get_component_model_params(diffmahpop_params) - ftpt0_params, tpc_params, logm0_params = component_model_params[:3] - logtc_params = component_model_params[3] - early_index_params, late_index_params = component_model_params[4:6] - cov_params = component_model_params[6] - - ftpt0_u_params = ftpt0_cens.get_unbounded_ftpt0_params(ftpt0_params) - tpc_u_params = tp_pdf_cens.get_unbounded_tp_cens_params(tpc_params) - logm0_u_params = logm0_pop.get_unbounded_m0pop_params(logm0_params) - logtc_u_params = logtc_pop.get_unbounded_logtc_params(logtc_params) - early_index_u_params = early_index_pop.get_unbounded_early_index_params( - early_index_params - ) - late_index_u_params = late_index_pop.get_unbounded_late_index_params( - late_index_params - ) - cov_u_params = covariance_kernels.get_unbounded_cov_params(cov_params) - - component_model_u_params = ( - ftpt0_u_params, - tpc_u_params, - logm0_u_params, - logtc_u_params, - early_index_u_params, - late_index_u_params, - cov_u_params, - ) - diffmahpop_u_params = DEFAULT_DIFFMAHPOP_U_PARAMS._make(DEFAULT_DIFFMAHPOP_U_PARAMS) - for u_params in component_model_u_params: - diffmahpop_u_params = diffmahpop_u_params._replace(**u_params._asdict()) - - return diffmahpop_u_params - - -@jjit -def _get_all_diffmahpop_params_from_varied( - varied_params, default_params=DEFAULT_DIFFMAHPOP_PARAMS -): - diffmahpop_params = default_params._replace(**varied_params._asdict()) - return diffmahpop_params - - -def get_varied_params_by_exclusion(all_params, excluded_pnames): - gen = zip(all_params._fields, all_params) - varied_pdict = OrderedDict( - [(name, float(x)) for (name, x) in gen if name not in excluded_pnames] - ) - VariedParams = namedtuple("VariedParams", varied_pdict.keys()) - varied_params = VariedParams(**varied_pdict) - return varied_params diff --git a/diffmah/diffmahpop_kernels/diffmahpop_params_censat.py b/diffmah/diffmahpop_kernels/diffmahpop_params_censat.py deleted file mode 100644 index 5dd3640..0000000 --- a/diffmah/diffmahpop_kernels/diffmahpop_params_censat.py +++ /dev/null @@ -1,243 +0,0 @@ -""" -""" - -from collections import OrderedDict, namedtuple - -from jax import jit as jjit - -from . import covariance_kernels, early_index_pop, ftpt0_cens, late_index_pop, logtc_pop -from .logm0_kernels import logm0_pop -from .t_peak_kernels import tp_pdf_cens, tp_pdf_sats - -DEFAULT_DIFFMAHPOP_PDICT = OrderedDict() -COMPONENT_PDICTS = ( - ftpt0_cens.DEFAULT_FTPT0_PDICT, - tp_pdf_cens.DEFAULT_TPCENS_PDICT, - tp_pdf_sats.DEFAULT_TP_SATS_PDICT, - logm0_pop.DEFAULT_LOGM0_PDICT, - logtc_pop.LOGTC_PDICT, - early_index_pop.EARLY_INDEX_PDICT, - late_index_pop.LATE_INDEX_PDICT, - covariance_kernels.DEFAULT_COV_PDICT, -) -for pdict in COMPONENT_PDICTS: - DEFAULT_DIFFMAHPOP_PDICT.update(pdict) -DiffmahPop_Params = namedtuple("DiffmahPop_Params", DEFAULT_DIFFMAHPOP_PDICT.keys()) -DEFAULT_DIFFMAHPOP_PARAMS = DiffmahPop_Params(**DEFAULT_DIFFMAHPOP_PDICT) - - -COMPONENT_U_PDICTS = ( - ftpt0_cens.DEFAULT_FTPT0_U_PARAMS._asdict(), - tp_pdf_cens.DEFAULT_TPCENS_U_PARAMS._asdict(), - tp_pdf_sats.DEFAULT_TP_SATS_U_PARAMS._asdict(), - logm0_pop.DEFAULT_LOGM0POP_U_PARAMS._asdict(), - logtc_pop.DEFAULT_LOGTC_U_PARAMS._asdict(), - early_index_pop.DEFAULT_EARLY_INDEX_U_PARAMS._asdict(), - late_index_pop.DEFAULT_LATE_INDEX_U_PARAMS._asdict(), - covariance_kernels.DEFAULT_COV_U_PARAMS._asdict(), -) -DEFAULT_DIFFMAHPOP_U_PDICT = OrderedDict() -for updict in COMPONENT_U_PDICTS: - DEFAULT_DIFFMAHPOP_U_PDICT.update(updict) -DiffmahPop_UParams = namedtuple("DiffmahPop_UParams", DEFAULT_DIFFMAHPOP_U_PDICT.keys()) -DEFAULT_DIFFMAHPOP_U_PARAMS = DiffmahPop_UParams(**DEFAULT_DIFFMAHPOP_U_PDICT) - - -@jjit -def get_component_model_params(diffmahpop_params): - ftpt0_cens_params = ftpt0_cens.FTPT0_Params( - *[getattr(diffmahpop_params, key) for key in ftpt0_cens.FTPT0_Params._fields] - ) - tp_pdf_cens_params = tp_pdf_cens.TPCens_Params( - *[getattr(diffmahpop_params, key) for key in tp_pdf_cens.TPCens_Params._fields] - ) - tp_pdf_sats_params = tp_pdf_sats.TP_Sats_Params( - *[getattr(diffmahpop_params, key) for key in tp_pdf_sats.TP_Sats_Params._fields] - ) - logm0_params = logm0_pop.LGM0Pop_Params( - *[getattr(diffmahpop_params, key) for key in logm0_pop.LGM0Pop_Params._fields] - ) - logtc_params = logtc_pop.Logtc_Params( - *[getattr(diffmahpop_params, key) for key in logtc_pop.Logtc_Params._fields] - ) - early_index_params = early_index_pop.EarlyIndex_Params( - *[ - getattr(diffmahpop_params, key) - for key in early_index_pop.EarlyIndex_Params._fields - ] - ) - late_index_params = late_index_pop.LateIndex_Params( - *[ - getattr(diffmahpop_params, key) - for key in late_index_pop.LateIndex_Params._fields - ] - ) - cov_params = covariance_kernels.CovParams( - *[ - getattr(diffmahpop_params, key) - for key in covariance_kernels.CovParams._fields - ] - ) - return ( - ftpt0_cens_params, - tp_pdf_cens_params, - tp_pdf_sats_params, - logm0_params, - logtc_params, - early_index_params, - late_index_params, - cov_params, - ) - - -@jjit -def get_component_model_u_params(diffmahpop_u_params): - ftpt0_cens_u_params = ftpt0_cens.FTPT0_UParams( - *[getattr(diffmahpop_u_params, key) for key in ftpt0_cens.FTPT0_UParams._fields] - ) - tp_pdf_cens_u_params = tp_pdf_cens.TPCens_UParams( - *[ - getattr(diffmahpop_u_params, key) - for key in tp_pdf_cens.TPCens_UParams._fields - ] - ) - tp_pdf_sats_u_params = tp_pdf_sats.TP_Sats_UParams( - *[ - getattr(diffmahpop_u_params, key) - for key in tp_pdf_sats.TP_Sats_UParams._fields - ] - ) - logm0_u_params = logm0_pop.LGM0Pop_UParams( - *[ - getattr(diffmahpop_u_params, key) - for key in logm0_pop.LGM0Pop_UParams._fields - ] - ) - logtc_u_params = logtc_pop.Logtc_UParams( - *[getattr(diffmahpop_u_params, key) for key in logtc_pop.Logtc_UParams._fields] - ) - early_index_u_params = early_index_pop.EarlyIndex_UParams( - *[ - getattr(diffmahpop_u_params, key) - for key in early_index_pop.EarlyIndex_UParams._fields - ] - ) - late_index_u_params = late_index_pop.LateIndex_UParams( - *[ - getattr(diffmahpop_u_params, key) - for key in late_index_pop.LateIndex_UParams._fields - ] - ) - cov_u_params = covariance_kernels.CovUParams( - *[ - getattr(diffmahpop_u_params, key) - for key in covariance_kernels.CovUParams._fields - ] - ) - - return ( - ftpt0_cens_u_params, - tp_pdf_cens_u_params, - tp_pdf_sats_u_params, - logm0_u_params, - logtc_u_params, - early_index_u_params, - late_index_u_params, - cov_u_params, - ) - - -@jjit -def get_diffmahpop_params_from_u_params(diffmahpop_u_params): - component_model_u_params = get_component_model_u_params(diffmahpop_u_params) - ftpt0_u_params, tpc_u_params, tps_u_params, logm0_u_params = ( - component_model_u_params[:4] - ) - logtc_u_params = component_model_u_params[4] - early_index_u_params, late_index_u_params = component_model_u_params[5:7] - cov_u_params = component_model_u_params[7] - - ftpt0_cens_params = ftpt0_cens.get_bounded_ftpt0_params(ftpt0_u_params) - tpc_params = tp_pdf_cens.get_bounded_tp_cens_params(tpc_u_params) - tps_params = tp_pdf_sats.get_bounded_tp_sat_params(tps_u_params) - logm0_params = logm0_pop.get_bounded_m0pop_params(logm0_u_params) - logtc_params = logtc_pop.get_bounded_logtc_params(logtc_u_params) - early_index_params = early_index_pop.get_bounded_early_index_params( - early_index_u_params - ) - late_index_params = late_index_pop.get_bounded_late_index_params( - late_index_u_params - ) - cov_params = covariance_kernels.get_bounded_cov_params(cov_u_params) - - component_model_params = ( - ftpt0_cens_params, - tpc_params, - tps_params, - logm0_params, - logtc_params, - early_index_params, - late_index_params, - cov_params, - ) - diffmahpop_params = DEFAULT_DIFFMAHPOP_PARAMS._make(DEFAULT_DIFFMAHPOP_PARAMS) - for params in component_model_params: - diffmahpop_params = diffmahpop_params._replace(**params._asdict()) - - return diffmahpop_params - - -@jjit -def get_diffmahpop_u_params_from_params(diffmahpop_params): - component_model_params = get_component_model_params(diffmahpop_params) - ftpt0_params, tpc_params, tps_params, logm0_params = component_model_params[:4] - logtc_params = component_model_params[4] - early_index_params, late_index_params = component_model_params[5:7] - cov_params = component_model_params[7] - - ftpt0_u_params = ftpt0_cens.get_unbounded_ftpt0_params(ftpt0_params) - tpc_u_params = tp_pdf_cens.get_unbounded_tp_cens_params(tpc_params) - tps_u_params = tp_pdf_sats.get_unbounded_tp_sat_params(tps_params) - logm0_u_params = logm0_pop.get_unbounded_m0pop_params(logm0_params) - logtc_u_params = logtc_pop.get_unbounded_logtc_params(logtc_params) - early_index_u_params = early_index_pop.get_unbounded_early_index_params( - early_index_params - ) - late_index_u_params = late_index_pop.get_unbounded_late_index_params( - late_index_params - ) - cov_u_params = covariance_kernels.get_unbounded_cov_params(cov_params) - - component_model_u_params = ( - ftpt0_u_params, - tpc_u_params, - tps_u_params, - logm0_u_params, - logtc_u_params, - early_index_u_params, - late_index_u_params, - cov_u_params, - ) - diffmahpop_u_params = DEFAULT_DIFFMAHPOP_U_PARAMS._make(DEFAULT_DIFFMAHPOP_U_PARAMS) - for u_params in component_model_u_params: - diffmahpop_u_params = diffmahpop_u_params._replace(**u_params._asdict()) - - return diffmahpop_u_params - - -@jjit -def _get_all_diffmahpop_params_from_varied( - varied_params, default_params=DEFAULT_DIFFMAHPOP_PARAMS -): - diffmahpop_params = default_params._replace(**varied_params._asdict()) - return diffmahpop_params - - -def get_varied_params_by_exclusion(all_params, excluded_pnames): - gen = zip(all_params._fields, all_params) - varied_pdict = OrderedDict( - [(name, float(x)) for (name, x) in gen if name not in excluded_pnames] - ) - VariedParams = namedtuple("VariedParams", varied_pdict.keys()) - varied_params = VariedParams(**varied_pdict) - return varied_params diff --git a/diffmah/diffmahpop_kernels/diffmahpop_params_monocensat.py b/diffmah/diffmahpop_kernels/diffmahpop_params_monocensat.py deleted file mode 100644 index 09671dc..0000000 --- a/diffmah/diffmahpop_kernels/diffmahpop_params_monocensat.py +++ /dev/null @@ -1,230 +0,0 @@ -""" -""" - -from collections import OrderedDict, namedtuple - -from jax import jit as jjit - -from . import covariance_kernels, early_index_pop, late_index_pop, logtc_pop -from .logm0_kernels import logm0_pop -from .t_peak_kernels import tp_pdf_monocens, tp_pdf_sats - -DEFAULT_DIFFMAHPOP_PDICT = OrderedDict() -COMPONENT_PDICTS = ( - tp_pdf_monocens.DEFAULT_TPCENS_PDICT, - tp_pdf_sats.DEFAULT_TP_SATS_PDICT, - logm0_pop.DEFAULT_LOGM0_PDICT, - logtc_pop.LOGTC_PDICT, - early_index_pop.EARLY_INDEX_PDICT, - late_index_pop.LATE_INDEX_PDICT, - covariance_kernels.DEFAULT_COV_PDICT, -) -for pdict in COMPONENT_PDICTS: - DEFAULT_DIFFMAHPOP_PDICT.update(pdict) -DiffmahPop_Params = namedtuple("DiffmahPop_Params", DEFAULT_DIFFMAHPOP_PDICT.keys()) -DEFAULT_DIFFMAHPOP_PARAMS = DiffmahPop_Params(**DEFAULT_DIFFMAHPOP_PDICT) - - -COMPONENT_U_PDICTS = ( - tp_pdf_monocens.DEFAULT_TPCENS_U_PARAMS._asdict(), - tp_pdf_sats.DEFAULT_TP_SATS_U_PARAMS._asdict(), - logm0_pop.DEFAULT_LOGM0POP_U_PARAMS._asdict(), - logtc_pop.DEFAULT_LOGTC_U_PARAMS._asdict(), - early_index_pop.DEFAULT_EARLY_INDEX_U_PARAMS._asdict(), - late_index_pop.DEFAULT_LATE_INDEX_U_PARAMS._asdict(), - covariance_kernels.DEFAULT_COV_U_PARAMS._asdict(), -) -DEFAULT_DIFFMAHPOP_U_PDICT = OrderedDict() -for updict in COMPONENT_U_PDICTS: - DEFAULT_DIFFMAHPOP_U_PDICT.update(updict) -DiffmahPop_UParams = namedtuple("DiffmahPop_UParams", DEFAULT_DIFFMAHPOP_U_PDICT.keys()) -DEFAULT_DIFFMAHPOP_U_PARAMS = DiffmahPop_UParams(**DEFAULT_DIFFMAHPOP_U_PDICT) - - -@jjit -def get_component_model_params(diffmahpop_params): - tp_pdf_monocens_params = tp_pdf_monocens.TPCens_Params( - *[ - getattr(diffmahpop_params, key) - for key in tp_pdf_monocens.TPCens_Params._fields - ] - ) - tp_pdf_sats_params = tp_pdf_sats.TP_Sats_Params( - *[getattr(diffmahpop_params, key) for key in tp_pdf_sats.TP_Sats_Params._fields] - ) - logm0_params = logm0_pop.LGM0Pop_Params( - *[getattr(diffmahpop_params, key) for key in logm0_pop.LGM0Pop_Params._fields] - ) - logtc_params = logtc_pop.Logtc_Params( - *[getattr(diffmahpop_params, key) for key in logtc_pop.Logtc_Params._fields] - ) - early_index_params = early_index_pop.EarlyIndex_Params( - *[ - getattr(diffmahpop_params, key) - for key in early_index_pop.EarlyIndex_Params._fields - ] - ) - late_index_params = late_index_pop.LateIndex_Params( - *[ - getattr(diffmahpop_params, key) - for key in late_index_pop.LateIndex_Params._fields - ] - ) - cov_params = covariance_kernels.CovParams( - *[ - getattr(diffmahpop_params, key) - for key in covariance_kernels.CovParams._fields - ] - ) - return ( - tp_pdf_monocens_params, - tp_pdf_sats_params, - logm0_params, - logtc_params, - early_index_params, - late_index_params, - cov_params, - ) - - -@jjit -def get_component_model_u_params(diffmahpop_u_params): - tp_pdf_monocens_u_params = tp_pdf_monocens.TPCens_UParams( - *[ - getattr(diffmahpop_u_params, key) - for key in tp_pdf_monocens.TPCens_UParams._fields - ] - ) - tp_pdf_sats_u_params = tp_pdf_sats.TP_Sats_UParams( - *[ - getattr(diffmahpop_u_params, key) - for key in tp_pdf_sats.TP_Sats_UParams._fields - ] - ) - logm0_u_params = logm0_pop.LGM0Pop_UParams( - *[ - getattr(diffmahpop_u_params, key) - for key in logm0_pop.LGM0Pop_UParams._fields - ] - ) - logtc_u_params = logtc_pop.Logtc_UParams( - *[getattr(diffmahpop_u_params, key) for key in logtc_pop.Logtc_UParams._fields] - ) - early_index_u_params = early_index_pop.EarlyIndex_UParams( - *[ - getattr(diffmahpop_u_params, key) - for key in early_index_pop.EarlyIndex_UParams._fields - ] - ) - late_index_u_params = late_index_pop.LateIndex_UParams( - *[ - getattr(diffmahpop_u_params, key) - for key in late_index_pop.LateIndex_UParams._fields - ] - ) - cov_u_params = covariance_kernels.CovUParams( - *[ - getattr(diffmahpop_u_params, key) - for key in covariance_kernels.CovUParams._fields - ] - ) - - return ( - tp_pdf_monocens_u_params, - tp_pdf_sats_u_params, - logm0_u_params, - logtc_u_params, - early_index_u_params, - late_index_u_params, - cov_u_params, - ) - - -@jjit -def get_diffmahpop_params_from_u_params(diffmahpop_u_params): - component_model_u_params = get_component_model_u_params(diffmahpop_u_params) - tpc_u_params, tps_u_params, logm0_u_params = component_model_u_params[:3] - logtc_u_params = component_model_u_params[3] - early_index_u_params, late_index_u_params = component_model_u_params[4:6] - cov_u_params = component_model_u_params[6] - - tpc_params = tp_pdf_monocens.get_bounded_tp_cens_params(tpc_u_params) - tps_params = tp_pdf_sats.get_bounded_tp_sat_params(tps_u_params) - logm0_params = logm0_pop.get_bounded_m0pop_params(logm0_u_params) - logtc_params = logtc_pop.get_bounded_logtc_params(logtc_u_params) - early_index_params = early_index_pop.get_bounded_early_index_params( - early_index_u_params - ) - late_index_params = late_index_pop.get_bounded_late_index_params( - late_index_u_params - ) - cov_params = covariance_kernels.get_bounded_cov_params(cov_u_params) - - component_model_params = ( - tpc_params, - tps_params, - logm0_params, - logtc_params, - early_index_params, - late_index_params, - cov_params, - ) - diffmahpop_params = DEFAULT_DIFFMAHPOP_PARAMS._make(DEFAULT_DIFFMAHPOP_PARAMS) - for params in component_model_params: - diffmahpop_params = diffmahpop_params._replace(**params._asdict()) - - return diffmahpop_params - - -@jjit -def get_diffmahpop_u_params_from_params(diffmahpop_params): - component_model_params = get_component_model_params(diffmahpop_params) - tpc_params, tps_params, logm0_params = component_model_params[:3] - logtc_params = component_model_params[3] - early_index_params, late_index_params = component_model_params[4:6] - cov_params = component_model_params[6] - - tpc_u_params = tp_pdf_monocens.get_unbounded_tp_cens_params(tpc_params) - tps_u_params = tp_pdf_sats.get_unbounded_tp_sat_params(tps_params) - logm0_u_params = logm0_pop.get_unbounded_m0pop_params(logm0_params) - logtc_u_params = logtc_pop.get_unbounded_logtc_params(logtc_params) - early_index_u_params = early_index_pop.get_unbounded_early_index_params( - early_index_params - ) - late_index_u_params = late_index_pop.get_unbounded_late_index_params( - late_index_params - ) - cov_u_params = covariance_kernels.get_unbounded_cov_params(cov_params) - - component_model_u_params = ( - tpc_u_params, - tps_u_params, - logm0_u_params, - logtc_u_params, - early_index_u_params, - late_index_u_params, - cov_u_params, - ) - diffmahpop_u_params = DEFAULT_DIFFMAHPOP_U_PARAMS._make(DEFAULT_DIFFMAHPOP_U_PARAMS) - for u_params in component_model_u_params: - diffmahpop_u_params = diffmahpop_u_params._replace(**u_params._asdict()) - - return diffmahpop_u_params - - -@jjit -def _get_all_diffmahpop_params_from_varied( - varied_params, default_params=DEFAULT_DIFFMAHPOP_PARAMS -): - diffmahpop_params = default_params._replace(**varied_params._asdict()) - return diffmahpop_params - - -def get_varied_params_by_exclusion(all_params, excluded_pnames): - gen = zip(all_params._fields, all_params) - varied_pdict = OrderedDict( - [(name, float(x)) for (name, x) in gen if name not in excluded_pnames] - ) - VariedParams = namedtuple("VariedParams", varied_pdict.keys()) - varied_params = VariedParams(**varied_pdict) - return varied_params diff --git a/diffmah/diffmahpop_kernels/late_index_pop.py b/diffmah/diffmahpop_kernels/late_index_pop.py deleted file mode 100644 index 513f4e5..0000000 --- a/diffmah/diffmahpop_kernels/late_index_pop.py +++ /dev/null @@ -1,89 +0,0 @@ -""" -""" - -from collections import OrderedDict, namedtuple - -from jax import jit as jjit -from jax import numpy as jnp -from jax import vmap - -from ..diffmah_kernels import MAH_PBOUNDS -from ..utils import _inverse_sigmoid, _sigmoid - -EPS = 1e-3 -LATE_INDEX_K = 1.0 - -LATE_INDEX_PDICT = OrderedDict( - late_index_x0=11.714, - late_index_ylo=0.196, - late_index_yhi=0.199, -) -LATE_INDEX_BOUNDS_PDICT = OrderedDict( - late_index_x0=(11.5, 14.0), late_index_ylo=(0.01, 0.2), late_index_yhi=(0.01, 0.2) -) - -LateIndex_Params = namedtuple("LateIndex_Params", LATE_INDEX_PDICT.keys()) -DEFAULT_LATE_INDEX_PARAMS = LateIndex_Params(**LATE_INDEX_PDICT) -LATE_INDEX_PBOUNDS = LateIndex_Params(**LATE_INDEX_BOUNDS_PDICT) -K_BOUNDING = 0.1 - - -@jjit -def _pred_late_index_kern(late_index_params, lgm_obs): - late_index_x0, late_index_ylo, late_index_yhi = late_index_params - late_index = _sigmoid( - lgm_obs, late_index_x0, LATE_INDEX_K, late_index_ylo, late_index_yhi - ) - ylo, yhi = MAH_PBOUNDS.late_index - late_index = jnp.clip(late_index, ylo + EPS, yhi - EPS) - return late_index - - -@jjit -def _get_bounded_late_index_param(u_param, bound): - lo, hi = bound - mid = 0.5 * (lo + hi) - return _sigmoid(u_param, mid, K_BOUNDING, lo, hi) - - -@jjit -def _get_unbounded_late_index_param(param, bound): - lo, hi = bound - mid = 0.5 * (lo + hi) - return _inverse_sigmoid(param, mid, K_BOUNDING, lo, hi) - - -_C = (0, 0) -_get_bounded_late_index_params_kern = jjit( - vmap(_get_bounded_late_index_param, in_axes=_C) -) -_get_unbounded_late_index_params_kern = jjit( - vmap(_get_unbounded_late_index_param, in_axes=_C) -) - - -@jjit -def get_bounded_late_index_params(u_params): - u_params = jnp.array( - [getattr(u_params, u_pname) for u_pname in _LATE_INDEX_UPNAMES] - ) - params = _get_bounded_late_index_params_kern( - jnp.array(u_params), jnp.array(LATE_INDEX_PBOUNDS) - ) - params = LateIndex_Params(*params) - return params - - -@jjit -def get_unbounded_late_index_params(params): - params = jnp.array([getattr(params, pname) for pname in LateIndex_Params._fields]) - u_params = _get_unbounded_late_index_params_kern( - jnp.array(params), jnp.array(LATE_INDEX_PBOUNDS) - ) - u_params = LateIndex_UParams(*u_params) - return u_params - - -_LATE_INDEX_UPNAMES = ["u_" + key for key in LateIndex_Params._fields] -LateIndex_UParams = namedtuple("LateIndex_UParams", _LATE_INDEX_UPNAMES) -DEFAULT_LATE_INDEX_U_PARAMS = get_unbounded_late_index_params(DEFAULT_LATE_INDEX_PARAMS) diff --git a/diffmah/diffmahpop_kernels/logm0_kernels/__init__.py b/diffmah/diffmahpop_kernels/logm0_kernels/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/diffmah/diffmahpop_kernels/logm0_kernels/logm0_c0_kernels.py b/diffmah/diffmahpop_kernels/logm0_kernels/logm0_c0_kernels.py deleted file mode 100644 index 2f8abf7..0000000 --- a/diffmah/diffmahpop_kernels/logm0_kernels/logm0_c0_kernels.py +++ /dev/null @@ -1,127 +0,0 @@ -""" -""" - -from collections import OrderedDict, namedtuple - -from jax import jit as jjit -from jax import numpy as jnp -from jax import value_and_grad, vmap - -from ...bfgs_wrapper import diffmah_fitter -from ...utils import _inverse_sigmoid, _sig_slope, _sigmoid - -DEFAULT_LGM0POP_C0_PDICT = OrderedDict( - lgm0pop_c0_ytp=0.011, - lgm0pop_c0_ylo=-0.066, - lgm0pop_c0_clip_c0=0.602, - lgm0pop_c0_clip_c1=-0.090, - lgm0pop_c0_t_obs_x0=1.825, -) -LGM0Pop_C0_Params = namedtuple("LGM0Pop_C0_Params", DEFAULT_LGM0POP_C0_PDICT.keys()) -DEFAULT_LGM0POP_C0_PARAMS = LGM0Pop_C0_Params(**DEFAULT_LGM0POP_C0_PDICT) - -_C0_UPNAMES = ["u_" + key for key in LGM0Pop_C0_Params._fields] -LGM0Pop_C0_UParams = namedtuple("LGM0Pop_C0_UParams", _C0_UPNAMES) - -LGM0POP_C0_BOUNDS_DICT = OrderedDict( - lgm0pop_c0_ytp=(0.01, 0.4), - lgm0pop_c0_ylo=(-0.15, -0.05), - lgm0pop_c0_clip_c0=(0.5, 0.9), - lgm0pop_c0_clip_c1=(-0.1, -0.01), - lgm0pop_c0_t_obs_x0=(1.5, 6.0), -) -LGM0POP_C0_BOUNDS = LGM0Pop_C0_Params(**LGM0POP_C0_BOUNDS_DICT) - -XTP = 15 -GLOBAL_K = 0.25 -K_BOUNDING = 0.1 - - -@jjit -def _pred_c0_kern(params, t_obs, t_peak): - pred_c0 = _sig_slope( - t_obs, - XTP, - params.lgm0pop_c0_ytp, - params.lgm0pop_c0_t_obs_x0, - GLOBAL_K, - params.lgm0pop_c0_ylo, - 0.0, - ) - clip = params.lgm0pop_c0_clip_c0 + params.lgm0pop_c0_clip_c1 * t_peak - pred_c0 = jnp.clip(pred_c0, min=clip) - return pred_c0 - - -@jjit -def _mse(x, y): - d = y - x - return jnp.mean(d * d) - - -@jjit -def _loss_kern_scalar(params, loss_data): - t_obs, t_peak, target_c0 = loss_data - pred_c0 = _pred_c0_kern(params, t_obs, t_peak) - return _mse(target_c0, pred_c0) - - -@jjit -def global_loss_kern(params, global_loss_data): - loss = 0.0 - for loss_data in global_loss_data: - loss = loss + _loss_kern_scalar(params, loss_data) - return loss - - -global_loss_and_grads_kern = jjit(value_and_grad(global_loss_kern)) - - -def fit_global_c0_model(global_loss_data, p_init=DEFAULT_LGM0POP_C0_PARAMS): - _res = diffmah_fitter(global_loss_and_grads_kern, p_init, global_loss_data) - p_best, loss_best, fit_terminates, code_used = _res - return p_best, loss_best, fit_terminates, code_used - - -@jjit -def _get_bounded_c0_param(u_param, bound): - lo, hi = bound - mid = 0.5 * (lo + hi) - return _sigmoid(u_param, mid, K_BOUNDING, lo, hi) - - -@jjit -def _get_unbounded_c0_param(param, bound): - lo, hi = bound - mid = 0.5 * (lo + hi) - return _inverse_sigmoid(param, mid, K_BOUNDING, lo, hi) - - -_C = (0, 0) -_get_bounded_c0_params_kern = jjit(vmap(_get_bounded_c0_param, in_axes=_C)) -_get_unbounded_c0_params_kern = jjit(vmap(_get_unbounded_c0_param, in_axes=_C)) - - -@jjit -def get_bounded_c0_params(u_params): - u_params = jnp.array([getattr(u_params, u_pname) for u_pname in _C0_UPNAMES]) - params = _get_bounded_c0_params_kern( - jnp.array(u_params), jnp.array(LGM0POP_C0_BOUNDS) - ) - c0_params = LGM0Pop_C0_Params(*params) - return c0_params - - -@jjit -def get_unbounded_c0_params(params): - params = jnp.array([getattr(params, pname) for pname in LGM0Pop_C0_Params._fields]) - u_params = _get_unbounded_c0_params_kern( - jnp.array(params), jnp.array(LGM0POP_C0_BOUNDS) - ) - c0_u_params = LGM0Pop_C0_UParams(*u_params) - return c0_u_params - - -DEFAULT_LGM0POP_C0_U_PARAMS = LGM0Pop_C0_UParams( - *get_unbounded_c0_params(DEFAULT_LGM0POP_C0_PARAMS) -) diff --git a/diffmah/diffmahpop_kernels/logm0_kernels/logm0_c1_kernels.py b/diffmah/diffmahpop_kernels/logm0_kernels/logm0_c1_kernels.py deleted file mode 100644 index db7b94d..0000000 --- a/diffmah/diffmahpop_kernels/logm0_kernels/logm0_c1_kernels.py +++ /dev/null @@ -1,138 +0,0 @@ -""" -""" - -from collections import OrderedDict, namedtuple - -from jax import jit as jjit -from jax import numpy as jnp -from jax import value_and_grad, vmap - -from ...bfgs_wrapper import diffmah_fitter -from ...utils import _inverse_sigmoid, _sig_slope, _sigmoid - -DEFAULT_LGM0POP_C1_PDICT = OrderedDict( - lgm0pop_c1_ytp=0.002, - lgm0pop_c1_ylo=-0.024, - lgm0pop_c1_clip_x0=4.437, - lgm0pop_c1_clip_ylo=0.071, - lgm0pop_c1_clip_yhi=0.002, - lgm0pop_c1_t_obs_x0=5.950, -) -LGM0Pop_C1_Params = namedtuple("LGM0Pop_C1_Params", DEFAULT_LGM0POP_C1_PDICT.keys()) -DEFAULT_LGM0POP_C1_PARAMS = LGM0Pop_C1_Params(**DEFAULT_LGM0POP_C1_PDICT) - - -LGM0POP_C1_BOUNDS_DICT = OrderedDict( - lgm0pop_c1_ytp=(0.001, 0.1), - lgm0pop_c1_ylo=(-0.05, -0.001), - lgm0pop_c1_clip_x0=(4.0, 11.0), - lgm0pop_c1_clip_ylo=(0.02, 0.15), - lgm0pop_c1_clip_yhi=(0.001, 0.05), - lgm0pop_c1_t_obs_x0=(3.0, 10.0), -) -LGM0POP_C1_BOUNDS = LGM0Pop_C1_Params(**LGM0POP_C1_BOUNDS_DICT) - -_C1_UPNAMES = ["u_" + key for key in LGM0Pop_C1_Params._fields] -LGM0Pop_C1_UParams = namedtuple("LGM0Pop_C1_UParams", _C1_UPNAMES) - -XTP = 10.0 -GLOBAL_K = 0.25 -CLIP_TP_K = 1.0 -K_BOUNDING = 0.1 - - -@jjit -def _pred_c1_kern(params, t_obs, t_peak): - pred_c1 = _sig_slope( - t_obs, - XTP, - params.lgm0pop_c1_ytp, - params.lgm0pop_c1_t_obs_x0, - GLOBAL_K, - params.lgm0pop_c1_ylo, - 0.0, - ) - - clip = _sigmoid( - t_peak, - params.lgm0pop_c1_clip_x0, - CLIP_TP_K, - params.lgm0pop_c1_clip_ylo, - params.lgm0pop_c1_clip_yhi, - ) - pred_c1 = jnp.clip(pred_c1, min=clip) - return pred_c1 - - -@jjit -def _mse(x, y): - d = y - x - return jnp.mean(d * d) - - -@jjit -def _loss_kern_scalar(params, loss_data): - t_obs, t_peak, target_c1 = loss_data - pred_c1 = _pred_c1_kern(params, t_obs, t_peak) - return _mse(target_c1, pred_c1) - - -@jjit -def global_loss_kern(params, global_loss_data): - loss = 0.0 - for loss_data in global_loss_data: - loss = loss + _loss_kern_scalar(params, loss_data) - return loss - - -global_loss_and_grads_kern = jjit(value_and_grad(global_loss_kern)) - - -def fit_global_c1_model(global_loss_data, p_init=DEFAULT_LGM0POP_C1_PARAMS): - _res = diffmah_fitter(global_loss_and_grads_kern, p_init, global_loss_data) - p_best, loss_best, fit_terminates, code_used = _res - return p_best, loss_best, fit_terminates, code_used - - -@jjit -def _get_bounded_c1_param(u_param, bound): - lo, hi = bound - mid = 0.5 * (lo + hi) - return _sigmoid(u_param, mid, K_BOUNDING, lo, hi) - - -@jjit -def _get_unbounded_c1_param(param, bound): - lo, hi = bound - mid = 0.5 * (lo + hi) - return _inverse_sigmoid(param, mid, K_BOUNDING, lo, hi) - - -_C = (0, 0) -_get_bounded_c1_params_kern = jjit(vmap(_get_bounded_c1_param, in_axes=_C)) -_get_unbounded_c1_params_kern = jjit(vmap(_get_unbounded_c1_param, in_axes=_C)) - - -@jjit -def get_bounded_c1_params(u_params): - u_params = jnp.array([getattr(u_params, u_pname) for u_pname in _C1_UPNAMES]) - params = _get_bounded_c1_params_kern( - jnp.array(u_params), jnp.array(LGM0POP_C1_BOUNDS) - ) - params = LGM0Pop_C1_Params(*params) - return params - - -@jjit -def get_unbounded_c1_params(params): - params = jnp.array([getattr(params, pname) for pname in LGM0Pop_C1_Params._fields]) - u_params = _get_unbounded_c1_params_kern( - jnp.array(params), jnp.array(LGM0POP_C1_BOUNDS) - ) - u_params = LGM0Pop_C1_UParams(*u_params) - return u_params - - -DEFAULT_LGM0POP_C1_U_PARAMS = LGM0Pop_C1_UParams( - *get_unbounded_c1_params(DEFAULT_LGM0POP_C1_PARAMS) -) diff --git a/diffmah/diffmahpop_kernels/logm0_kernels/logm0_pop.py b/diffmah/diffmahpop_kernels/logm0_kernels/logm0_pop.py deleted file mode 100644 index 75e6e3c..0000000 --- a/diffmah/diffmahpop_kernels/logm0_kernels/logm0_pop.py +++ /dev/null @@ -1,73 +0,0 @@ -""" -""" - -from collections import OrderedDict, namedtuple - -from jax import jit as jjit - -from . import logm0_c0_kernels, logm0_c1_kernels - -DEFAULT_LOGM0_PDICT = OrderedDict() -DEFAULT_LOGM0_PDICT.update(logm0_c0_kernels.DEFAULT_LGM0POP_C0_PDICT) -DEFAULT_LOGM0_PDICT.update(logm0_c1_kernels.DEFAULT_LGM0POP_C1_PDICT) - -LGM0Pop_Params = namedtuple("LGM0Pop_Params", DEFAULT_LOGM0_PDICT.keys()) -DEFAULT_LOGM0POP_PARAMS = LGM0Pop_Params(**DEFAULT_LOGM0_PDICT) - -_UPNAMES = ["u_" + key for key in LGM0Pop_Params._fields] -LGM0Pop_UParams = namedtuple("LGM0Pop_UParams", _UPNAMES) - -DEFAULT_LOGM0_BOUNDS_DICT = OrderedDict() -DEFAULT_LOGM0_BOUNDS_DICT.update(logm0_c0_kernels.LGM0POP_C0_BOUNDS_DICT) -DEFAULT_LOGM0_BOUNDS_DICT.update(logm0_c1_kernels.LGM0POP_C1_BOUNDS_DICT) -LGM0POP_BOUNDS = LGM0Pop_Params(**DEFAULT_LOGM0_BOUNDS_DICT) - -TP_LGMP = 12.0 - - -@jjit -def _pred_logm0_kern(logm0_params, lgm_obs, t_obs, t_peak): - c0_params = logm0_params[:5] - c0_params = logm0_c0_kernels.DEFAULT_LGM0POP_C0_PARAMS._make(c0_params) - c1_params = logm0_params[5:] - c1_params = logm0_c1_kernels.DEFAULT_LGM0POP_C1_PARAMS._make(c1_params) - c0 = logm0_c0_kernels._pred_c0_kern(c0_params, t_obs, t_peak) - c1 = logm0_c1_kernels._pred_c1_kern(c1_params, t_obs, t_peak) - delta_lgm = c0 + c1 * (lgm_obs - TP_LGMP) - return lgm_obs + delta_lgm - - -@jjit -def get_bounded_m0pop_params(u_params): - c0_u_params = [getattr(u_params, key) for key in logm0_c0_kernels._C0_UPNAMES] - c0_u_params = logm0_c0_kernels.LGM0Pop_C0_UParams(*c0_u_params) - c0_params = logm0_c0_kernels.get_bounded_c0_params(c0_u_params) - - c1_u_params = [getattr(u_params, key) for key in logm0_c1_kernels._C1_UPNAMES] - c1_u_params = logm0_c1_kernels.LGM0Pop_C1_UParams(*c1_u_params) - c1_params = logm0_c1_kernels.get_bounded_c1_params(c1_u_params) - - params = LGM0Pop_Params(*c0_params, *c1_params) - return params - - -@jjit -def get_unbounded_m0pop_params(params): - c0_pnames = logm0_c0_kernels.LGM0Pop_C0_Params._fields - c1_pnames = logm0_c1_kernels.LGM0Pop_C1_Params._fields - - c0_params = [getattr(params, key) for key in c0_pnames] - c0_params = logm0_c0_kernels.LGM0Pop_C0_Params(*c0_params) - c0_u_params = logm0_c0_kernels.get_unbounded_c0_params(c0_params) - - c1_params = [getattr(params, key) for key in c1_pnames] - c1_params = logm0_c1_kernels.LGM0Pop_C1_Params(*c1_params) - c1_u_params = logm0_c1_kernels.get_unbounded_c1_params(c1_params) - - u_params = LGM0Pop_UParams(*c0_u_params, *c1_u_params) - return u_params - - -DEFAULT_LOGM0POP_U_PARAMS = LGM0Pop_UParams( - *get_unbounded_m0pop_params(DEFAULT_LOGM0POP_PARAMS) -) diff --git a/diffmah/diffmahpop_kernels/logm0_kernels/tests/__init__.py b/diffmah/diffmahpop_kernels/logm0_kernels/tests/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/diffmah/diffmahpop_kernels/logm0_kernels/tests/test_logm0_c0_kernels.py b/diffmah/diffmahpop_kernels/logm0_kernels/tests/test_logm0_c0_kernels.py deleted file mode 100644 index b170acd..0000000 --- a/diffmah/diffmahpop_kernels/logm0_kernels/tests/test_logm0_c0_kernels.py +++ /dev/null @@ -1,65 +0,0 @@ -""" -""" - -import numpy as np - -from .. import logm0_c0_kernels as c0k - -TOL = 1e-2 - - -def test_param_u_param_names_propagate_properly(): - gen = zip( - c0k.DEFAULT_LGM0POP_C0_U_PARAMS._fields, c0k.DEFAULT_LGM0POP_C0_PARAMS._fields - ) - for u_key, key in gen: - assert u_key[:2] == "u_" - assert u_key[2:] == key - - inferred_default_params = c0k.get_bounded_c0_params(c0k.DEFAULT_LGM0POP_C0_U_PARAMS) - assert set(inferred_default_params._fields) == set( - c0k.DEFAULT_LGM0POP_C0_PARAMS._fields - ) - - inferred_default_u_params = c0k.get_unbounded_c0_params( - c0k.DEFAULT_LGM0POP_C0_PARAMS - ) - assert set(inferred_default_u_params._fields) == set( - c0k.DEFAULT_LGM0POP_C0_U_PARAMS._fields - ) - - -def test_get_bounded_params_fails_when_passing_params(): - try: - c0k.get_bounded_c0_params(c0k.DEFAULT_LGM0POP_C0_PARAMS) - raise NameError("get_bounded_c0_params should not accept params") - except AttributeError: - pass - - -def test_get_unbounded_params_fails_when_passing_u_params(): - try: - c0k.get_unbounded_c0_params(c0k.DEFAULT_LGM0POP_C0_U_PARAMS) - raise NameError("get_unbounded_c0_params should not accept u_params") - except AttributeError: - pass - - -def test_param_u_param_inversion(): - assert np.allclose( - c0k.DEFAULT_LGM0POP_C0_PARAMS, - c0k.get_bounded_c0_params(c0k.DEFAULT_LGM0POP_C0_U_PARAMS), - rtol=TOL, - ) - - inferred_default_params = c0k.get_bounded_c0_params( - c0k.get_unbounded_c0_params(c0k.DEFAULT_LGM0POP_C0_PARAMS) - ) - assert np.allclose(c0k.DEFAULT_LGM0POP_C0_PARAMS, inferred_default_params, rtol=TOL) - - -def test_default_params_are_in_bounds(): - for key in c0k.DEFAULT_LGM0POP_C0_PARAMS._fields: - val = getattr(c0k.DEFAULT_LGM0POP_C0_PARAMS, key) - bound = getattr(c0k.LGM0POP_C0_BOUNDS, key) - assert bound[0] < val < bound[1] diff --git a/diffmah/diffmahpop_kernels/logm0_kernels/tests/test_logm0_c1_kernels.py b/diffmah/diffmahpop_kernels/logm0_kernels/tests/test_logm0_c1_kernels.py deleted file mode 100644 index 09db603..0000000 --- a/diffmah/diffmahpop_kernels/logm0_kernels/tests/test_logm0_c1_kernels.py +++ /dev/null @@ -1,65 +0,0 @@ -""" -""" - -import numpy as np - -from .. import logm0_c1_kernels as c1k - -TOL = 1e-2 - - -def test_param_u_param_names_propagate_properly(): - gen = zip( - c1k.DEFAULT_LGM0POP_C1_U_PARAMS._fields, c1k.DEFAULT_LGM0POP_C1_PARAMS._fields - ) - for u_key, key in gen: - assert u_key[:2] == "u_" - assert u_key[2:] == key - - inferred_default_params = c1k.get_bounded_c1_params(c1k.DEFAULT_LGM0POP_C1_U_PARAMS) - assert set(inferred_default_params._fields) == set( - c1k.DEFAULT_LGM0POP_C1_PARAMS._fields - ) - - inferred_default_u_params = c1k.get_unbounded_c1_params( - c1k.DEFAULT_LGM0POP_C1_PARAMS - ) - assert set(inferred_default_u_params._fields) == set( - c1k.DEFAULT_LGM0POP_C1_U_PARAMS._fields - ) - - -def test_get_bounded_params_fails_when_passing_params(): - try: - c1k.get_bounded_c1_params(c1k.DEFAULT_LGM0POP_C1_PARAMS) - raise NameError("get_bounded_c0_params should not accept params") - except AttributeError: - pass - - -def test_get_unbounded_params_fails_when_passing_u_params(): - try: - c1k.get_unbounded_c0_params(c1k.DEFAULT_LGM0POP_C1_U_PARAMS) - raise NameError("get_unbounded_c1_params should not accept u_params") - except AttributeError: - pass - - -def test_param_u_param_inversion(): - assert np.allclose( - c1k.DEFAULT_LGM0POP_C1_PARAMS, - c1k.get_bounded_c1_params(c1k.DEFAULT_LGM0POP_C1_U_PARAMS), - rtol=TOL, - ) - - inferred_default_params = c1k.get_bounded_c1_params( - c1k.get_unbounded_c1_params(c1k.DEFAULT_LGM0POP_C1_PARAMS) - ) - assert np.allclose(c1k.DEFAULT_LGM0POP_C1_PARAMS, inferred_default_params, rtol=TOL) - - -def test_default_params_are_in_bounds(): - for key in c1k.DEFAULT_LGM0POP_C1_PARAMS._fields: - val = getattr(c1k.DEFAULT_LGM0POP_C1_PARAMS, key) - bound = getattr(c1k.LGM0POP_C1_BOUNDS, key) - assert bound[0] < val < bound[1], key diff --git a/diffmah/diffmahpop_kernels/logm0_kernels/tests/test_logm0_pop.py b/diffmah/diffmahpop_kernels/logm0_kernels/tests/test_logm0_pop.py deleted file mode 100644 index b09c0f1..0000000 --- a/diffmah/diffmahpop_kernels/logm0_kernels/tests/test_logm0_pop.py +++ /dev/null @@ -1,86 +0,0 @@ -""" -""" - -import numpy as np -from jax import random as jran - -from .. import logm0_pop as m0pop - -TOL = 1e-2 - - -def test_param_u_param_names_propagate_properly(): - gen = zip( - m0pop.DEFAULT_LOGM0POP_U_PARAMS._fields, - m0pop.DEFAULT_LOGM0POP_PARAMS._fields, - ) - for u_key, key in gen: - assert u_key[:2] == "u_" - assert u_key[2:] == key - - inferred_default_params = m0pop.get_bounded_m0pop_params( - m0pop.DEFAULT_LOGM0POP_U_PARAMS - ) - assert set(inferred_default_params._fields) == set( - m0pop.DEFAULT_LOGM0POP_PARAMS._fields - ) - - inferred_default_u_params = m0pop.get_unbounded_m0pop_params( - m0pop.DEFAULT_LOGM0POP_PARAMS - ) - assert set(inferred_default_u_params._fields) == set( - m0pop.DEFAULT_LOGM0POP_U_PARAMS._fields - ) - - -def test_get_bounded_params_fails_when_passing_params(): - try: - m0pop.get_bounded_m0pop_params(m0pop.DEFAULT_LGM0POP_PARAMS) - raise NameError("get_bounded_m0pop_params should not accept params") - except AttributeError: - pass - - -def test_get_unbounded_params_fails_when_passing_u_params(): - try: - m0pop.get_unbounded_m0pop_params(m0pop.DEFAULT_LOGM0POP_U_PARAMS) - raise NameError("get_unbounded_m0pop_params should not accept u_params") - except AttributeError: - pass - - -def test_param_u_param_inversion(): - assert np.allclose( - m0pop.DEFAULT_LOGM0POP_PARAMS, - m0pop.get_bounded_m0pop_params(m0pop.DEFAULT_LOGM0POP_U_PARAMS), - rtol=TOL, - ) - - inferred_default_params = m0pop.get_bounded_m0pop_params( - m0pop.get_unbounded_m0pop_params(m0pop.DEFAULT_LOGM0POP_PARAMS) - ) - assert np.allclose(m0pop.DEFAULT_LOGM0POP_PARAMS, inferred_default_params, rtol=TOL) - - -def test_default_params_are_in_bounds(): - for key in m0pop.DEFAULT_LOGM0POP_PARAMS._fields: - val = getattr(m0pop.DEFAULT_LOGM0POP_PARAMS, key) - bound = getattr(m0pop.LGM0POP_BOUNDS, key) - assert bound[0] < val < bound[1] - - -def test_pred_logm0_kern(): - ran_key = jran.key(0) - n_tests = 1_000 - for __ in range(n_tests): - lgm_key, t_obs_key, t_peak_key, ran_key = jran.split(ran_key, 4) - lgm_obs = jran.uniform(lgm_key, minval=5, maxval=16, shape=()) - t_obs = jran.uniform(t_obs_key, minval=1, maxval=20, shape=()) - t_peak = jran.uniform(t_peak_key, minval=1.5, maxval=20, shape=()) - lgm0 = m0pop._pred_logm0_kern( - m0pop.DEFAULT_LOGM0POP_PARAMS, lgm_obs, t_obs, t_peak - ) - assert lgm0.shape == () - assert np.isfinite(lgm0) - assert lgm0 > 0 - assert lgm0 < 20 diff --git a/diffmah/diffmahpop_kernels/logtc_pop.py b/diffmah/diffmahpop_kernels/logtc_pop.py deleted file mode 100644 index ce21200..0000000 --- a/diffmah/diffmahpop_kernels/logtc_pop.py +++ /dev/null @@ -1,117 +0,0 @@ -""" -""" - -from collections import OrderedDict, namedtuple - -from jax import jit as jjit -from jax import numpy as jnp -from jax import value_and_grad, vmap - -from ..diffmah_kernels import MAH_PBOUNDS -from ..utils import _inverse_sigmoid, _sig_slope, _sigmoid - -EPS = 1e-3 -K_BOUNDING = 0.1 -LOGTC_PDICT = OrderedDict( - lgm_c0_tp_ytp_tobs_c0=0.213, - lgm_c0_tp_ytp_tobs_c1=0.030, - lgm_c0_tp_ylo=0.076, - lgm_c0_tp_yhi=-0.082, - lgm_c1=0.092, - logtc_c0_ss_x0=11.444, -) -LOGTC_BOUNDS_PDICT = OrderedDict( - lgm_c0_tp_ytp_tobs_c0=(0.2, 0.9), - lgm_c0_tp_ytp_tobs_c1=(0.0, 0.05), - lgm_c0_tp_ylo=(0.02, 0.15), - lgm_c0_tp_yhi=(-0.25, 0.0), - lgm_c1=(0.02, 0.15), - logtc_c0_ss_x0=(11.0, 13.0), -) -Logtc_Params = namedtuple("Logtc_Params", LOGTC_PDICT.keys()) -DEFAULT_LOGTC_PARAMS = Logtc_Params(**LOGTC_PDICT) -LOGTC_PBOUNDS = Logtc_Params(**LOGTC_BOUNDS_PDICT) - -K_BOUNDING = 0.1 -TAUC_LGMP = 12.0 -C0_SS_XTP = 10.0 -C0_SS_K = 0.5 - - -@jjit -def _pred_logtc_kern(params, lgm_obs, t_obs, t_peak): - lgm_c0 = _get_c0(params, t_peak, t_obs) - logtc = lgm_c0 + params.lgm_c1 * (lgm_obs - TAUC_LGMP) - ylo, yhi = MAH_PBOUNDS.logtc - logtc = jnp.clip(logtc, ylo + EPS, yhi - EPS) - return logtc - - -@jjit -def _get_c0(params, t_peak, t_obs): - ytp = params.lgm_c0_tp_ytp_tobs_c0 + params.lgm_c0_tp_ytp_tobs_c1 * t_obs - ylo, yhi = params.lgm_c0_tp_ylo, params.lgm_c0_tp_yhi - c0 = _sig_slope(t_peak, C0_SS_XTP, ytp, params.logtc_c0_ss_x0, C0_SS_K, ylo, yhi) - return c0 - - -@jjit -def _mse(x, y): - d = y - x - return jnp.mean(d * d) - - -@jjit -def _loss_kern(u_params, loss_data): - u_params = Logtc_UParams(*u_params) - params = get_bounded_logtc_params(u_params) - lgm_obs, t_obs, t_peak, logtc_target = loss_data - logtc_pred = _pred_logtc_kern(params, lgm_obs, t_obs, t_peak) - return _mse(logtc_pred, logtc_target) - - -loss_and_grads_kern = jjit(value_and_grad(_loss_kern)) - - -@jjit -def _get_bounded_logtc_param(u_param, bound): - lo, hi = bound - mid = 0.5 * (lo + hi) - return _sigmoid(u_param, mid, K_BOUNDING, lo, hi) - - -@jjit -def _get_unbounded_logtc_param(param, bound): - lo, hi = bound - mid = 0.5 * (lo + hi) - return _inverse_sigmoid(param, mid, K_BOUNDING, lo, hi) - - -_C = (0, 0) -_get_bounded_logtc_params_kern = jjit(vmap(_get_bounded_logtc_param, in_axes=_C)) -_get_unbounded_logtc_params_kern = jjit(vmap(_get_unbounded_logtc_param, in_axes=_C)) - - -@jjit -def get_bounded_logtc_params(u_params): - u_params = jnp.array([getattr(u_params, u_pname) for u_pname in _LOGTC_UPNAMES]) - params = _get_bounded_logtc_params_kern( - jnp.array(u_params), jnp.array(LOGTC_PBOUNDS) - ) - params = Logtc_Params(*params) - return params - - -@jjit -def get_unbounded_logtc_params(params): - params = jnp.array([getattr(params, pname) for pname in Logtc_Params._fields]) - u_params = _get_unbounded_logtc_params_kern( - jnp.array(params), jnp.array(LOGTC_PBOUNDS) - ) - u_params = Logtc_UParams(*u_params) - return u_params - - -_LOGTC_UPNAMES = ["u_" + key for key in Logtc_Params._fields] -Logtc_UParams = namedtuple("Logtc_UParams", _LOGTC_UPNAMES) -DEFAULT_LOGTC_U_PARAMS = get_unbounded_logtc_params(DEFAULT_LOGTC_PARAMS) diff --git a/diffmah/diffmahpop_kernels/mc_diffmahpop_kernels.py b/diffmah/diffmahpop_kernels/mc_diffmahpop_kernels.py deleted file mode 100644 index 8fc72d6..0000000 --- a/diffmah/diffmahpop_kernels/mc_diffmahpop_kernels.py +++ /dev/null @@ -1,195 +0,0 @@ -""" -""" - -from functools import partial - -from jax import jit as jjit -from jax import numpy as jnp -from jax import random as jran -from jax import vmap - -from ..diffmah_kernels import ( - DiffmahParams, - DiffmahUParams, - get_bounded_mah_params, - get_unbounded_mah_params, - mah_halopop, - mah_singlehalo, -) -from . import ftpt0_cens -from .covariance_kernels import _get_diffmahpop_cov -from .diffmahpop_params import get_component_model_params -from .early_index_pop import _pred_early_index_kern -from .late_index_pop import _pred_late_index_kern -from .logm0_kernels.logm0_pop import _pred_logm0_kern -from .logtc_pop import _pred_logtc_kern -from .t_peak_kernels.tp_pdf_cens import mc_tpeak_singlecen - -N_TP_PER_HALO = 40 -T_OBS_FIT_MIN = 0.5 -NH_PER_M0BIN = 200 - - -@jjit -def mc_mean_diffmah_params(diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0): - t_0 = 10**lgt0 - model_params = get_component_model_params(diffmahpop_params) - ( - ftpt0_cens_params, - tp_pdf_cens_params, - logm0_params, - logtc_params, - early_index_params, - late_index_params, - cov_params, - ) = model_params - ftpt0 = ftpt0_cens._ftpt0_kernel(ftpt0_cens_params, lgm_obs, t_obs) - - tpc_key, ran_key = jran.split(ran_key, 2) - - lgm_obs = lgm_obs - t_obs = t_obs - args = tp_pdf_cens_params, tpc_key, lgm_obs, t_obs, t_0 - t_peak = mc_tpeak_singlecen(*args) - - ftpt0_key, ran_key = jran.split(ran_key, 2) - mc_tpt0 = jran.uniform(ftpt0_key, shape=()) < ftpt0 - - logm0_tpt0 = _pred_logm0_kern(logm0_params, lgm_obs, t_obs, t_0) - logm0_tp = _pred_logm0_kern(logm0_params, lgm_obs, t_obs, t_peak) - - logtc_tpt0 = _pred_logtc_kern(logtc_params, lgm_obs, t_obs, t_0) - logtc_tp = _pred_logtc_kern(logtc_params, lgm_obs, t_obs, t_peak) - - early_index_tpt0 = _pred_early_index_kern(early_index_params, lgm_obs, t_obs, t_0) - early_index_tp = _pred_early_index_kern(early_index_params, lgm_obs, t_obs, t_peak) - - late_index_tpt0 = _pred_late_index_kern(late_index_params, lgm_obs) - late_index_tp = _pred_late_index_kern(late_index_params, lgm_obs) - - dmah_tpt0 = DiffmahParams(logm0_tpt0, logtc_tpt0, early_index_tpt0, late_index_tpt0) - dmah_tp = DiffmahParams(logm0_tp, logtc_tp, early_index_tp, late_index_tp) - - return dmah_tpt0, dmah_tp, t_peak, ftpt0, mc_tpt0 - - -@jjit -def mc_diffmah_params_singlecen(diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0): - dmah_tpt0, dmah_tp, t_peak, ftpt0, mc_tpt0 = mc_mean_diffmah_params( - diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0 - ) - u_dmah_tpt0 = get_unbounded_mah_params(dmah_tpt0) - u_dmah_tp = get_unbounded_mah_params(dmah_tp) - - cov = _get_diffmahpop_cov(diffmahpop_params, lgm_obs) - - ran_key, tpt0_key, tp_key = jran.split(ran_key, 3) - ran_diffmah_u_params_tpt0 = jran.multivariate_normal( - tpt0_key, jnp.array(u_dmah_tpt0), cov, shape=() - ) - ran_diffmah_u_params_tp = jran.multivariate_normal( - tp_key, jnp.array(u_dmah_tp), cov, shape=() - ) - ran_diffmah_u_params_tpt0 = DiffmahUParams(*ran_diffmah_u_params_tpt0) - ran_diffmah_u_params_tp = DiffmahUParams(*ran_diffmah_u_params_tp) - - mah_params_tpt0 = get_bounded_mah_params(ran_diffmah_u_params_tpt0) - mah_params_tp = get_bounded_mah_params(ran_diffmah_u_params_tp) - return mah_params_tpt0, mah_params_tp, t_peak, ftpt0, mc_tpt0 - - -_A = (None, 0, 0, 0, None) -_mc_diffmah_params_vmap_kern = jjit(vmap(mc_diffmah_params_singlecen, in_axes=_A)) - - -@jjit -def mc_diffmah_params_cenpop(diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0): - ran_keys = jran.split(ran_key, lgm_obs.size) - return _mc_diffmah_params_vmap_kern( - diffmahpop_params, lgm_obs, t_obs, ran_keys, lgt0 - ) - - -@jjit -def _mc_diffmah_singlecen(diffmahpop_params, tarr, lgm_obs, t_obs, ran_key, lgt0): - _res = mc_diffmah_params_singlecen(diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0) - mah_params_tpt0, mah_params_tp, t_peak, ftpt0, mc_tpt0 = _res - dmhdt_tpt0, log_mah_tpt0 = mah_singlehalo(mah_params_tpt0, tarr, 10**lgt0, lgt0) - dmhdt_tp, log_mah_tp = mah_singlehalo(mah_params_tp, tarr, t_peak, lgt0) - _ret = ( - mah_params_tpt0, - mah_params_tp, - t_peak, - ftpt0, - mc_tpt0, - dmhdt_tpt0, - log_mah_tpt0, - dmhdt_tp, - log_mah_tp, - ) - return _ret - - -_V = (None, None, 0, 0, 0, None) -_mc_diffmah_singlecen_vmap_kern = jjit(vmap(_mc_diffmah_singlecen, in_axes=_V)) - - -@partial(jjit, static_argnames=["n_mc"]) -def _mc_diffmah_halo_sample( - diffmahpop_params, tarr, lgm_obs, t_obs, ran_key, lgt0, n_mc=NH_PER_M0BIN -): - zz = jnp.zeros(n_mc) - ran_keys = jran.split(ran_key, n_mc) - return _mc_diffmah_singlecen_vmap_kern( - diffmahpop_params, tarr, lgm_obs + zz, t_obs + zz, ran_keys, lgt0 - ) - - -@jjit -def _mc_diffmah_cenpop(diffmahpop_params, tarr, lgm_obs, t_obs, ran_key, lgt0): - ran_keys = jran.split(ran_key, lgm_obs.size) - _res = _mc_diffmah_params_vmap_kern( - diffmahpop_params, lgm_obs, t_obs, ran_keys, lgt0 - ) - mah_params_tpt0, mah_params_tp, t_peak, ftpt0, mc_tpt0 = _res - tpt0 = jnp.zeros_like(t_peak) + 10**lgt0 - dmhdt_tpt0, log_mah_tpt0 = mah_halopop(mah_params_tpt0, tarr, tpt0, lgt0) - dmhdt_tp, log_mah_tp = mah_halopop(mah_params_tp, tarr, t_peak, lgt0) - _ret = ( - mah_params_tpt0, - mah_params_tp, - t_peak, - ftpt0, - mc_tpt0, - dmhdt_tpt0, - log_mah_tpt0, - dmhdt_tp, - log_mah_tp, - ) - return _ret - - -@jjit -def predict_mah_moments_singlebin( - diffmahpop_params, tarr, lgm_obs, t_obs, ran_key, lgt0 -): - _res = _mc_diffmah_halo_sample( - diffmahpop_params, tarr, lgm_obs, t_obs, ran_key, lgt0 - ) - ( - mah_params_tpt0, - mah_params_tp, - t_peak, - ftpt0, - mc_tpt0, - dmhdt_tpt0, - log_mah_tpt0, - dmhdt_tp, - log_mah_tp, - ) = _res - - f = ftpt0.reshape((-1, 1)) - mean_log_mah = jnp.mean(f * log_mah_tpt0 + (1 - f) * log_mah_tp, axis=0) - std_log_mah = jnp.std(f * log_mah_tpt0 + (1 - f) * log_mah_tp, axis=0) - - return mean_log_mah, std_log_mah diff --git a/diffmah/diffmahpop_kernels/mc_diffmahpop_kernels_cens.py b/diffmah/diffmahpop_kernels/mc_diffmahpop_kernels_cens.py deleted file mode 100644 index 2c79485..0000000 --- a/diffmah/diffmahpop_kernels/mc_diffmahpop_kernels_cens.py +++ /dev/null @@ -1,200 +0,0 @@ -""" -""" - -from functools import partial - -from jax import jit as jjit -from jax import numpy as jnp -from jax import random as jran -from jax import vmap - -from ..diffmah_kernels import ( - DiffmahParams, - DiffmahUParams, - get_bounded_mah_params, - get_unbounded_mah_params, - mah_halopop, - mah_singlehalo, -) -from . import ftpt0_cens -from .covariance_kernels import _get_diffmahpop_cov -from .diffmahpop_params_censat import get_component_model_params -from .early_index_pop import _pred_early_index_kern -from .late_index_pop import _pred_late_index_kern -from .logm0_kernels.logm0_pop import _pred_logm0_kern -from .logtc_pop import _pred_logtc_kern -from .t_peak_kernels.tp_pdf_cens import mc_tpeak_singlecen - -N_TP_PER_HALO = 40 -T_OBS_FIT_MIN = 0.5 -NH_PER_M0BIN = 200 - - -@jjit -def mc_mean_diffmah_params(diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0): - t_0 = 10**lgt0 - model_params = get_component_model_params(diffmahpop_params) - ( - ftpt0_cens_params, - tp_pdf_cens_params, - tp_pdf_sats_params, - logm0_params, - logtc_params, - early_index_params, - late_index_params, - cov_params, - ) = model_params - ftpt0 = ftpt0_cens._ftpt0_kernel(ftpt0_cens_params, lgm_obs, t_obs) - - tpc_key, ran_key = jran.split(ran_key, 2) - - lgm_obs = lgm_obs - t_obs = t_obs - args = tp_pdf_cens_params, tpc_key, lgm_obs, t_obs, t_0 - t_peak = mc_tpeak_singlecen(*args) - - ftpt0_key, ran_key = jran.split(ran_key, 2) - mc_tpt0 = jran.uniform(ftpt0_key, shape=()) < ftpt0 - - logm0_tpt0 = _pred_logm0_kern(logm0_params, lgm_obs, t_obs, t_0) - logm0_tp = _pred_logm0_kern(logm0_params, lgm_obs, t_obs, t_peak) - - logtc_tpt0 = _pred_logtc_kern(logtc_params, lgm_obs, t_obs, t_0) - logtc_tp = _pred_logtc_kern(logtc_params, lgm_obs, t_obs, t_peak) - - early_index_tpt0 = _pred_early_index_kern(early_index_params, lgm_obs, t_obs, t_0) - early_index_tp = _pred_early_index_kern(early_index_params, lgm_obs, t_obs, t_peak) - - late_index_tpt0 = _pred_late_index_kern(late_index_params, lgm_obs) - late_index_tp = _pred_late_index_kern(late_index_params, lgm_obs) - - dmah_tpt0 = DiffmahParams(logm0_tpt0, logtc_tpt0, early_index_tpt0, late_index_tpt0) - dmah_tp = DiffmahParams(logm0_tp, logtc_tp, early_index_tp, late_index_tp) - - return dmah_tpt0, dmah_tp, t_peak, ftpt0, mc_tpt0 - - -@jjit -def mc_diffmah_params_singlecen(diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0): - dmah_tpt0, dmah_tp, t_peak, ftpt0, mc_tpt0 = mc_mean_diffmah_params( - diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0 - ) - u_dmah_tpt0 = get_unbounded_mah_params(dmah_tpt0) - u_dmah_tp = get_unbounded_mah_params(dmah_tp) - - cov = _get_diffmahpop_cov(diffmahpop_params, lgm_obs) - - ran_key, tpt0_key, tp_key = jran.split(ran_key, 3) - ran_diffmah_u_params_tpt0 = jran.multivariate_normal( - tpt0_key, jnp.array(u_dmah_tpt0), cov, shape=() - ) - ran_diffmah_u_params_tp = jran.multivariate_normal( - tp_key, jnp.array(u_dmah_tp), cov, shape=() - ) - ran_diffmah_u_params_tpt0 = DiffmahUParams(*ran_diffmah_u_params_tpt0) - ran_diffmah_u_params_tp = DiffmahUParams(*ran_diffmah_u_params_tp) - - mah_params_tpt0 = get_bounded_mah_params(ran_diffmah_u_params_tpt0) - mah_params_tp = get_bounded_mah_params(ran_diffmah_u_params_tp) - return mah_params_tpt0, mah_params_tp, t_peak, ftpt0, mc_tpt0 - - -_A = (None, 0, 0, 0, None) -_mc_diffmah_params_vmap_kern = jjit(vmap(mc_diffmah_params_singlecen, in_axes=_A)) - - -@jjit -def mc_diffmah_params_cenpop(diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0): - ran_keys = jran.split(ran_key, lgm_obs.size) - return _mc_diffmah_params_vmap_kern( - diffmahpop_params, lgm_obs, t_obs, ran_keys, lgt0 - ) - - -@jjit -def _mc_diffmah_singlecen(diffmahpop_params, tarr, lgm_obs, t_obs, ran_key, lgt0): - _res = mc_diffmah_params_singlecen(diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0) - mah_params_tpt0, mah_params_tp, t_peak, ftpt0, mc_tpt0 = _res - dmhdt_tpt0, log_mah_tpt0 = mah_singlehalo(mah_params_tpt0, tarr, 10**lgt0, lgt0) - dmhdt_tp, log_mah_tp = mah_singlehalo(mah_params_tp, tarr, t_peak, lgt0) - _ret = ( - mah_params_tpt0, - mah_params_tp, - t_peak, - ftpt0, - mc_tpt0, - dmhdt_tpt0, - log_mah_tpt0, - dmhdt_tp, - log_mah_tp, - ) - return _ret - - -_V = (None, None, 0, 0, 0, None) -_mc_diffmah_singlecen_vmap_kern = jjit(vmap(_mc_diffmah_singlecen, in_axes=_V)) - - -@partial(jjit, static_argnames=["n_mc"]) -def _mc_diffmah_halo_sample( - diffmahpop_params, tarr, lgm_obs, t_obs, ran_key, lgt0, n_mc=NH_PER_M0BIN -): - zz = jnp.zeros(n_mc) - ran_keys = jran.split(ran_key, n_mc) - return _mc_diffmah_singlecen_vmap_kern( - diffmahpop_params, tarr, lgm_obs + zz, t_obs + zz, ran_keys, lgt0 - ) - - -@jjit -def _mc_diffmah_cenpop(diffmahpop_params, tarr, lgm_obs, t_obs, ran_key, lgt0): - ran_keys = jran.split(ran_key, lgm_obs.size) - _res = _mc_diffmah_params_vmap_kern( - diffmahpop_params, lgm_obs, t_obs, ran_keys, lgt0 - ) - mah_params_tpt0, mah_params_tp, t_peak, ftpt0, mc_tpt0 = _res - tpt0 = jnp.zeros_like(t_peak) + 10**lgt0 - dmhdt_tpt0, log_mah_tpt0 = mah_halopop(mah_params_tpt0, tarr, tpt0, lgt0) - dmhdt_tp, log_mah_tp = mah_halopop(mah_params_tp, tarr, t_peak, lgt0) - _ret = ( - mah_params_tpt0, - mah_params_tp, - t_peak, - ftpt0, - mc_tpt0, - dmhdt_tpt0, - log_mah_tpt0, - dmhdt_tp, - log_mah_tp, - ) - return _ret - - -@jjit -def predict_mah_moments_singlebin( - diffmahpop_params, tarr, lgm_obs, t_obs, ran_key, lgt0 -): - _res = _mc_diffmah_halo_sample( - diffmahpop_params, tarr, lgm_obs, t_obs, ran_key, lgt0 - ) - ( - mah_params_tpt0, - mah_params_tp, - t_peak, - ftpt0, - mc_tpt0, - dmhdt_tpt0, - log_mah_tpt0, - dmhdt_tp, - log_mah_tp, - ) = _res - - f = ftpt0.reshape((-1, 1)) - mean_log_mah = jnp.mean(f * log_mah_tpt0 + (1 - f) * log_mah_tp, axis=0) - std_log_mah = jnp.std(f * log_mah_tpt0 + (1 - f) * log_mah_tp, axis=0) - - frac_peaked_tpt0 = jnp.mean(f * dmhdt_tpt0 == 0, axis=0) - frac_peaked_tp = jnp.mean((1 - f) * dmhdt_tp == 0, axis=0) - frac_peaked = frac_peaked_tpt0 + frac_peaked_tp - - return mean_log_mah, std_log_mah, frac_peaked diff --git a/diffmah/diffmahpop_kernels/mc_diffmahpop_kernels_censat.py b/diffmah/diffmahpop_kernels/mc_diffmahpop_kernels_censat.py deleted file mode 100644 index ae96083..0000000 --- a/diffmah/diffmahpop_kernels/mc_diffmahpop_kernels_censat.py +++ /dev/null @@ -1,249 +0,0 @@ -""" -""" - -from functools import partial - -from jax import jit as jjit -from jax import numpy as jnp -from jax import random as jran -from jax import vmap - -from ..diffmah_kernels import ( - DiffmahParams, - DiffmahUParams, - get_bounded_mah_params, - get_unbounded_mah_params, - mah_singlehalo, -) -from . import ftpt0_cens -from .covariance_kernels import _get_diffmahpop_cov -from .diffmahpop_params_censat import get_component_model_params -from .early_index_pop import _pred_early_index_kern -from .late_index_pop import _pred_late_index_kern -from .logm0_kernels.logm0_pop import _pred_logm0_kern -from .logtc_pop import _pred_logtc_kern -from .t_peak_kernels.tp_pdf_cens import mc_tpeak_singlecen -from .t_peak_kernels.tp_pdf_sats import mc_tpeak_singlesat - -N_TP_PER_HALO = 40 -T_OBS_FIT_MIN = 0.5 -NH_PER_M0BIN = 200 - - -@jjit -def mc_mean_diffmah_params(diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0): - t_0 = 10**lgt0 - model_params = get_component_model_params(diffmahpop_params) - ( - ftpt0_cens_params, - tp_pdf_cens_params, - tp_pdf_sats_params, - logm0_params, - logtc_params, - early_index_params, - late_index_params, - cov_params, - ) = model_params - frac_tpt0_cens = ftpt0_cens._ftpt0_kernel(ftpt0_cens_params, lgm_obs, t_obs) - - tpc_key, tps_key, ran_key = jran.split(ran_key, 3) - - args = tp_pdf_cens_params, tpc_key, lgm_obs, t_obs, t_0 - t_peak_cens = mc_tpeak_singlecen(*args) - - t_peak_sats = mc_tpeak_singlesat(tp_pdf_sats_params, ran_key, lgm_obs, t_obs) - - ftpt0_key, ran_key = jran.split(ran_key, 2) - mc_tpt0_cens = jran.uniform(ftpt0_key, shape=()) < frac_tpt0_cens - - logm0_tpt0_cens = _pred_logm0_kern(logm0_params, lgm_obs, t_obs, t_0) - logm0_tp_cens = _pred_logm0_kern(logm0_params, lgm_obs, t_obs, t_peak_cens) - logm0_sats = _pred_logm0_kern(logm0_params, lgm_obs, t_obs, t_peak_sats) - - logtc_tpt0_cens = _pred_logtc_kern(logtc_params, lgm_obs, t_obs, t_0) - logtc_tp_cens = _pred_logtc_kern(logtc_params, lgm_obs, t_obs, t_peak_cens) - logtc_sats = _pred_logtc_kern(logtc_params, lgm_obs, t_obs, t_peak_sats) - - early_index_tpt0_cens = _pred_early_index_kern( - early_index_params, lgm_obs, t_obs, t_0 - ) - early_index_tp_cens = _pred_early_index_kern( - early_index_params, lgm_obs, t_obs, t_peak_cens - ) - early_index_sats = _pred_early_index_kern( - early_index_params, lgm_obs, t_obs, t_peak_sats - ) - - late_index_tpt0_cens = _pred_late_index_kern(late_index_params, lgm_obs) - late_index_tp_cens = _pred_late_index_kern(late_index_params, lgm_obs) - late_index_sats = _pred_late_index_kern(late_index_params, lgm_obs) - - dmah_tpt0_cens = DiffmahParams( - logm0_tpt0_cens, logtc_tpt0_cens, early_index_tpt0_cens, late_index_tpt0_cens - ) - dmah_tp_cens = DiffmahParams( - logm0_tp_cens, logtc_tp_cens, early_index_tp_cens, late_index_tp_cens - ) - - dmah_sats = DiffmahParams(logm0_sats, logtc_sats, early_index_sats, late_index_sats) - - return ( - dmah_tpt0_cens, - dmah_tp_cens, - t_peak_cens, - frac_tpt0_cens, - mc_tpt0_cens, - t_peak_sats, - dmah_sats, - ) - - -@jjit -def mc_diffmah_params_single_censat(diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0): - _res = mc_mean_diffmah_params(diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0) - ( - dmah_tpt0_cens, - dmah_tp_cens, - t_peak_cens, - frac_tpt0_cens, - mc_tpt0_cens, - t_peak_sats, - dmah_sats, - ) = _res - u_dmah_tpt0_cens = get_unbounded_mah_params(dmah_tpt0_cens) - u_dmah_tp_cens = get_unbounded_mah_params(dmah_sats) - - u_dmah_sats = get_unbounded_mah_params(dmah_tp_cens) - - cov = _get_diffmahpop_cov(diffmahpop_params, lgm_obs) - - ran_key, tpt0_cens_key, tp_cens_key = jran.split(ran_key, 3) - ran_diffmah_u_params_tpt0_cens = jran.multivariate_normal( - tpt0_cens_key, jnp.array(u_dmah_tpt0_cens), cov, shape=() - ) - ran_diffmah_u_params_tp_cens = jran.multivariate_normal( - tp_cens_key, jnp.array(u_dmah_tp_cens), cov, shape=() - ) - sats_key = tp_cens_key - ran_diffmah_u_params_sats = jran.multivariate_normal( - sats_key, jnp.array(u_dmah_sats), cov, shape=() - ) - - ran_diffmah_u_params_tpt0_cens = DiffmahUParams(*ran_diffmah_u_params_tpt0_cens) - ran_diffmah_u_params_tp_cens = DiffmahUParams(*ran_diffmah_u_params_tp_cens) - ran_diffmah_u_params_sats = DiffmahUParams(*ran_diffmah_u_params_sats) - - mah_params_tpt0_cens = get_bounded_mah_params(ran_diffmah_u_params_tpt0_cens) - mah_params_tp_cens = get_bounded_mah_params(ran_diffmah_u_params_tp_cens) - mah_params_sats = get_bounded_mah_params(ran_diffmah_u_params_sats) - - return ( - mah_params_tpt0_cens, - mah_params_tp_cens, - t_peak_cens, - frac_tpt0_cens, - mc_tpt0_cens, - t_peak_sats, - mah_params_sats, - ) - - -_A = (None, 0, 0, 0, None) -_mc_diffmah_params_vmap_kern = jjit(vmap(mc_diffmah_params_single_censat, in_axes=_A)) - - -@jjit -def mc_diffmah_params_cenpop(diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0): - ran_keys = jran.split(ran_key, lgm_obs.size) - return _mc_diffmah_params_vmap_kern( - diffmahpop_params, lgm_obs, t_obs, ran_keys, lgt0 - ) - - -@jjit -def _mc_diffmah_single_censat(diffmahpop_params, tarr, lgm_obs, t_obs, ran_key, lgt0): - _res = mc_diffmah_params_single_censat( - diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0 - ) - ( - mah_params_tpt0_cens, - mah_params_tp_cens, - t_peak_cens, - frac_tpt0_cens, - mc_tpt0_cens, - t_peak_sats, - mah_params_sats, - ) = _res - dmhdt_tpt0_cens, log_mah_tpt0_cens = mah_singlehalo( - mah_params_tpt0_cens, tarr, 10**lgt0, lgt0 - ) - dmhdt_tp_cens, log_mah_tp_cens = mah_singlehalo( - mah_params_tp_cens, tarr, t_peak_cens, lgt0 - ) - dmhdt_sats, log_mah_sats = mah_singlehalo( - mah_params_tp_cens, tarr, t_peak_sats, lgt0 - ) - _ret = ( - mah_params_tpt0_cens, - mah_params_tp_cens, - t_peak_cens, - frac_tpt0_cens, - mc_tpt0_cens, - dmhdt_tpt0_cens, - log_mah_tpt0_cens, - dmhdt_tp_cens, - log_mah_tp_cens, - dmhdt_sats, - log_mah_sats, - ) - return _ret - - -_V = (None, None, 0, 0, 0, None) -_mc_diffmah_single_censat_vmap_kern = jjit(vmap(_mc_diffmah_single_censat, in_axes=_V)) - - -@partial(jjit, static_argnames=["n_mc"]) -def _mc_diffmah_halo_sample_censat( - diffmahpop_params, tarr, lgm_obs, t_obs, ran_key, lgt0, n_mc=NH_PER_M0BIN -): - zz = jnp.zeros(n_mc) - ran_keys = jran.split(ran_key, n_mc) - return _mc_diffmah_single_censat_vmap_kern( - diffmahpop_params, tarr, lgm_obs + zz, t_obs + zz, ran_keys, lgt0 - ) - - -@jjit -def predict_mah_moments_singlebin_censat( - diffmahpop_params, tarr, lgm_obs, t_obs, ran_key, lgt0 -): - _res = _mc_diffmah_halo_sample_censat( - diffmahpop_params, tarr, lgm_obs, t_obs, ran_key, lgt0 - ) - ( - mah_params_tpt0_cens, - mah_params_tp_cens, - t_peak_cens, - frac_tpt0_cens, - mc_tpt0_cens, - dmhdt_tpt0_cens, - log_mah_tpt0_cens, - dmhdt_tp_cens, - log_mah_tp_cens, - dmhdt_sats, - log_mah_sats, - ) = _res - - f = frac_tpt0_cens.reshape((-1, 1)) - mean_log_mah_cens = jnp.mean( - f * log_mah_tpt0_cens + (1 - f) * log_mah_tp_cens, axis=0 - ) - std_log_mah_cens = jnp.std( - f * log_mah_tpt0_cens + (1 - f) * log_mah_tp_cens, axis=0 - ) - - mean_log_mah_sats = jnp.mean(f * log_mah_sats + (1 - f) * log_mah_sats, axis=0) - std_log_mah_sats = jnp.std(f * log_mah_sats + (1 - f) * log_mah_sats, axis=0) - - return mean_log_mah_cens, std_log_mah_cens, mean_log_mah_sats, std_log_mah_sats diff --git a/diffmah/diffmahpop_kernels/mc_diffmahpop_kernels_monocens.py b/diffmah/diffmahpop_kernels/mc_diffmahpop_kernels_monocens.py deleted file mode 100644 index 63a0fca..0000000 --- a/diffmah/diffmahpop_kernels/mc_diffmahpop_kernels_monocens.py +++ /dev/null @@ -1,130 +0,0 @@ -""" -""" - -from functools import partial - -from jax import jit as jjit -from jax import numpy as jnp -from jax import random as jran -from jax import vmap - -from ..diffmah_kernels import ( - DiffmahParams, - DiffmahUParams, - get_bounded_mah_params, - get_unbounded_mah_params, - mah_singlehalo, -) -from .covariance_kernels import _get_diffmahpop_cov -from .diffmahpop_params_monocensat import get_component_model_params -from .early_index_pop import _pred_early_index_kern -from .late_index_pop import _pred_late_index_kern -from .logm0_kernels.logm0_pop import _pred_logm0_kern -from .logtc_pop import _pred_logtc_kern -from .t_peak_kernels.tp_pdf_monocens import mc_tpeak_singlecen - -N_TP_PER_HALO = 40 -T_OBS_FIT_MIN = 0.5 -NH_PER_M0BIN = 200 - - -@jjit -def mc_mean_diffmah_params(diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0): - t_0 = 10**lgt0 - model_params = get_component_model_params(diffmahpop_params) - ( - tp_pdf_cens_params, - tp_pdf_sats_params, - logm0_params, - logtc_params, - early_index_params, - late_index_params, - cov_params, - ) = model_params - - tpc_key, ran_key = jran.split(ran_key, 2) - - lgm_obs = lgm_obs - t_obs = t_obs - args = tp_pdf_cens_params, lgm_obs, tpc_key, t_0 - t_peak = mc_tpeak_singlecen(*args) - - logm0 = _pred_logm0_kern(logm0_params, lgm_obs, t_obs, t_peak) - logtc = _pred_logtc_kern(logtc_params, lgm_obs, t_obs, t_peak) - early_index = _pred_early_index_kern(early_index_params, lgm_obs, t_obs, t_peak) - late_index = _pred_late_index_kern(late_index_params, lgm_obs) - mah_params = DiffmahParams(logm0, logtc, early_index, late_index) - - return mah_params, t_peak - - -@jjit -def mc_diffmah_params_singlecen(diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0): - mean_mah_params, t_peak = mc_mean_diffmah_params( - diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0 - ) - mean_mah_u_params = get_unbounded_mah_params(mean_mah_params) - - cov = _get_diffmahpop_cov(diffmahpop_params, lgm_obs) - - ran_key, p_key = jran.split(ran_key, 2) - ran_diffmah_u_params_tp = jran.multivariate_normal( - p_key, jnp.array(mean_mah_u_params), cov, shape=() - ) - ran_diffmah_u_params = DiffmahUParams(*ran_diffmah_u_params_tp) - - mah_params = get_bounded_mah_params(ran_diffmah_u_params) - return mah_params, t_peak - - -_A = (None, 0, 0, 0, None) -_mc_diffmah_params_vmap_kern = jjit(vmap(mc_diffmah_params_singlecen, in_axes=_A)) - - -@jjit -def mc_diffmah_params_cenpop(diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0): - ran_keys = jran.split(ran_key, lgm_obs.size) - return _mc_diffmah_params_vmap_kern( - diffmahpop_params, lgm_obs, t_obs, ran_keys, lgt0 - ) - - -@jjit -def _mc_diffmah_singlecen(diffmahpop_params, tarr, lgm_obs, t_obs, ran_key, lgt0): - _res = mc_diffmah_params_singlecen(diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0) - mah_params, t_peak = _res - dmhdt, log_mah = mah_singlehalo(mah_params, tarr, t_peak, lgt0) - _ret = (mah_params, t_peak, dmhdt, log_mah) - return _ret - - -_V = (None, None, 0, 0, 0, None) -_mc_diffmah_singlecen_vmap_kern = jjit(vmap(_mc_diffmah_singlecen, in_axes=_V)) - - -@partial(jjit, static_argnames=["n_mc"]) -def _mc_diffmah_halo_sample( - diffmahpop_params, tarr, lgm_obs, t_obs, ran_key, lgt0, n_mc=NH_PER_M0BIN -): - zz = jnp.zeros(n_mc) - ran_keys = jran.split(ran_key, n_mc) - return _mc_diffmah_singlecen_vmap_kern( - diffmahpop_params, tarr, lgm_obs + zz, t_obs + zz, ran_keys, lgt0 - ) - - -@jjit -def predict_mah_moments_singlebin( - diffmahpop_params, tarr, lgm_obs, t_obs, ran_key, lgt0 -): - _res = _mc_diffmah_halo_sample( - diffmahpop_params, tarr, lgm_obs, t_obs, ran_key, lgt0 - ) - mah_params, t_peak, dmhdt, log_mah = _res - - mean_log_mah = jnp.mean(log_mah, axis=0) - std_log_mah = jnp.std(log_mah, axis=0) - - frac_peaked = jnp.mean(dmhdt == 0, axis=0) - - return mean_log_mah, std_log_mah, frac_peaked diff --git a/diffmah/diffmahpop_kernels/mc_diffmahpop_kernels_monocens_fixed_tpeak.py b/diffmah/diffmahpop_kernels/mc_diffmahpop_kernels_monocens_fixed_tpeak.py deleted file mode 100644 index 562e1f0..0000000 --- a/diffmah/diffmahpop_kernels/mc_diffmahpop_kernels_monocens_fixed_tpeak.py +++ /dev/null @@ -1,123 +0,0 @@ -""" -""" - -from jax import jit as jjit -from jax import numpy as jnp -from jax import random as jran -from jax import vmap - -from ..diffmah_kernels import ( - DiffmahParams, - DiffmahUParams, - get_bounded_mah_params, - get_unbounded_mah_params, - mah_singlehalo, -) -from .covariance_kernels import _get_diffmahpop_cov -from .diffmahpop_params_monocensat import get_component_model_params -from .early_index_pop import _pred_early_index_kern -from .late_index_pop import _pred_late_index_kern -from .logm0_kernels.logm0_pop import _pred_logm0_kern -from .logtc_pop import _pred_logtc_kern - -N_TP_PER_HALO = 40 -T_OBS_FIT_MIN = 0.5 -NH_PER_M0BIN = 200 - - -@jjit -def mc_mean_diffmah_params(diffmahpop_params, lgm_obs, t_obs, t_peak): - model_params = get_component_model_params(diffmahpop_params) - ( - tp_pdf_cens_params, - tp_pdf_sats_params, - logm0_params, - logtc_params, - early_index_params, - late_index_params, - cov_params, - ) = model_params - - lgm_obs = lgm_obs - t_obs = t_obs - - logm0 = _pred_logm0_kern(logm0_params, lgm_obs, t_obs, t_peak) - logtc = _pred_logtc_kern(logtc_params, lgm_obs, t_obs, t_peak) - early_index = _pred_early_index_kern(early_index_params, lgm_obs, t_obs, t_peak) - late_index = _pred_late_index_kern(late_index_params, lgm_obs) - mah_params = DiffmahParams(logm0, logtc, early_index, late_index) - - return mah_params - - -@jjit -def mc_diffmah_params_singlecen(diffmahpop_params, lgm_obs, t_obs, t_peak, ran_key): - mean_mah_params = mc_mean_diffmah_params(diffmahpop_params, lgm_obs, t_obs, t_peak) - mean_mah_u_params = get_unbounded_mah_params(mean_mah_params) - - cov = _get_diffmahpop_cov(diffmahpop_params, lgm_obs) - - ran_key, p_key = jran.split(ran_key, 2) - ran_diffmah_u_params_tp = jran.multivariate_normal( - p_key, jnp.array(mean_mah_u_params), cov, shape=() - ) - ran_diffmah_u_params = DiffmahUParams(*ran_diffmah_u_params_tp) - - mah_params = get_bounded_mah_params(ran_diffmah_u_params) - return mah_params - - -_A = (None, 0, 0, 0, 0) -_mc_diffmah_params_vmap_kern = jjit(vmap(mc_diffmah_params_singlecen, in_axes=_A)) - - -@jjit -def mc_diffmah_params_cenpop(diffmahpop_params, lgm_obs, t_obs, t_peak, ran_key, lgt0): - ran_keys = jran.split(ran_key, lgm_obs.size) - return _mc_diffmah_params_vmap_kern( - diffmahpop_params, lgm_obs, t_obs, ran_keys, t_peak, lgt0 - ) - - -@jjit -def _mc_diffmah_singlecen( - diffmahpop_params, tarr, lgm_obs, t_obs, t_peak, ran_key, lgt0 -): - mah_params = mc_diffmah_params_singlecen( - diffmahpop_params, lgm_obs, t_obs, t_peak, ran_key - ) - dmhdt, log_mah = mah_singlehalo(mah_params, tarr, t_peak, lgt0) - _ret = (mah_params, dmhdt, log_mah) - return _ret - - -_V = (None, None, 0, 0, 0, 0, None) -_mc_diffmah_singlecen_vmap_kern = jjit(vmap(_mc_diffmah_singlecen, in_axes=_V)) - - -@jjit -def _mc_diffmah_halo_sample( - diffmahpop_params, tarr, lgm_obs, t_obs, t_peak_sample, ran_key, lgt0 -): - zz = jnp.zeros_like(t_peak_sample) - ran_keys = jran.split(ran_key, zz.size) - return _mc_diffmah_singlecen_vmap_kern( - diffmahpop_params, tarr, lgm_obs + zz, t_obs + zz, t_peak_sample, ran_keys, lgt0 - ) - - -@jjit -def predict_mah_moments_singlebin( - diffmahpop_params, tarr, lgm_obs, t_obs, t_peak_sample, ran_key, lgt0 -): - _res = _mc_diffmah_halo_sample( - diffmahpop_params, tarr, lgm_obs, t_obs, t_peak_sample, ran_key, lgt0 - ) - mah_params, dmhdt, log_mah = _res - - mean_log_mah = jnp.mean(log_mah, axis=0) - std_log_mah = jnp.std(log_mah, axis=0) - - frac_peaked = jnp.mean(dmhdt == 0, axis=0) - - return mean_log_mah, std_log_mah, frac_peaked diff --git a/diffmah/diffmahpop_kernels/mc_diffmahpop_kernels_monosats.py b/diffmah/diffmahpop_kernels/mc_diffmahpop_kernels_monosats.py deleted file mode 100644 index ea31983..0000000 --- a/diffmah/diffmahpop_kernels/mc_diffmahpop_kernels_monosats.py +++ /dev/null @@ -1,132 +0,0 @@ -""" -""" - -from functools import partial - -from jax import jit as jjit -from jax import numpy as jnp -from jax import random as jran -from jax import vmap - -from ..diffmah_kernels import ( - DiffmahParams, - DiffmahUParams, - get_bounded_mah_params, - get_unbounded_mah_params, - mah_singlehalo, -) -from .covariance_kernels import _get_diffmahpop_cov -from .diffmahpop_params_monocensat import get_component_model_params -from .early_index_pop import _pred_early_index_kern -from .late_index_pop import _pred_late_index_kern -from .logm0_kernels.logm0_pop import _pred_logm0_kern -from .logtc_pop import _pred_logtc_kern -from .t_peak_kernels.tp_pdf_sats import mc_tpeak_singlesat - -N_TP_PER_HALO = 40 -T_OBS_FIT_MIN = 0.5 -NH_PER_M0BIN = 200 - - -@jjit -def mc_mean_diffmah_params(diffmahpop_params, lgm_obs, t_obs, ran_key): - model_params = get_component_model_params(diffmahpop_params) - ( - tp_pdf_cens_params, - tp_pdf_sats_params, - logm0_params, - logtc_params, - early_index_params, - late_index_params, - cov_params, - ) = model_params - - tpc_key, ran_key = jran.split(ran_key, 2) - - lgm_obs = lgm_obs - t_obs = t_obs - args = tp_pdf_sats_params, tpc_key, lgm_obs, t_obs - t_peak_sats = mc_tpeak_singlesat(*args) - - logm0_tp = _pred_logm0_kern(logm0_params, lgm_obs, t_obs, t_peak_sats) - - logtc_tp = _pred_logtc_kern(logtc_params, lgm_obs, t_obs, t_peak_sats) - - early_index_tp = _pred_early_index_kern( - early_index_params, lgm_obs, t_obs, t_peak_sats - ) - - late_index_tp = _pred_late_index_kern(late_index_params, lgm_obs) - - dmah_sats = DiffmahParams(logm0_tp, logtc_tp, early_index_tp, late_index_tp) - - return dmah_sats, t_peak_sats - - -@jjit -def mc_diffmah_params_singlesat(diffmahpop_params, lgm_obs, t_obs, ran_key): - dmah_sats, t_peak_sats = mc_mean_diffmah_params( - diffmahpop_params, lgm_obs, t_obs, ran_key - ) - u_dmah_sats = get_unbounded_mah_params(dmah_sats) - - cov = _get_diffmahpop_cov(diffmahpop_params, lgm_obs) - - ran_key, tpt0_key, tp_key = jran.split(ran_key, 3) - - ran_diffmah_u_params_tp = jran.multivariate_normal( - tp_key, jnp.array(u_dmah_sats), cov, shape=() - ) - ran_diffmah_u_params = DiffmahUParams(*ran_diffmah_u_params_tp) - - mah_params = get_bounded_mah_params(ran_diffmah_u_params) - return mah_params, t_peak_sats - - -_A = (None, 0, 0, 0) -_mc_diffmah_params_vmap_kern = jjit(vmap(mc_diffmah_params_singlesat, in_axes=_A)) - - -@jjit -def mc_diffmah_params_satpop(diffmahpop_params, lgm_obs, t_obs, ran_key): - ran_keys = jran.split(ran_key, lgm_obs.size) - return _mc_diffmah_params_vmap_kern(diffmahpop_params, lgm_obs, t_obs, ran_keys) - - -@jjit -def _mc_diffmah_singlesat(diffmahpop_params, tarr, lgm_obs, t_obs, ran_key, lgt0): - _res = mc_diffmah_params_singlesat(diffmahpop_params, lgm_obs, t_obs, ran_key) - mah_params, t_peak_sats = _res - dmhdt_sats, log_mah_sats = mah_singlehalo(mah_params, tarr, t_peak_sats, lgt0) - return mah_params, t_peak_sats, dmhdt_sats, log_mah_sats - - -_V = (None, None, 0, 0, 0, None) -_mc_diffmah_singlesat_vmap_kern = jjit(vmap(_mc_diffmah_singlesat, in_axes=_V)) - - -@partial(jjit, static_argnames=["n_mc"]) -def _mc_diffmah_halo_sample( - diffmahpop_params, tarr, lgm_obs, t_obs, ran_key, lgt0, n_mc=NH_PER_M0BIN -): - zz = jnp.zeros(n_mc) - ran_keys = jran.split(ran_key, n_mc) - return _mc_diffmah_singlesat_vmap_kern( - diffmahpop_params, tarr, lgm_obs + zz, t_obs + zz, ran_keys, lgt0 - ) - - -@jjit -def predict_mah_moments_singlebin( - diffmahpop_params, tarr, lgm_obs, t_obs, ran_key, lgt0 -): - _res = _mc_diffmah_halo_sample( - diffmahpop_params, tarr, lgm_obs, t_obs, ran_key, lgt0 - ) - mah_params, t_peak_sats, dmhdt, log_mah = _res - - mean_log_mah = jnp.mean(log_mah, axis=0) - std_log_mah = jnp.std(log_mah, axis=0) - frac_peaked = jnp.mean(dmhdt == 0, axis=0) - - return mean_log_mah, std_log_mah, frac_peaked diff --git a/diffmah/diffmahpop_kernels/mc_diffmahpop_kernels_sats.py b/diffmah/diffmahpop_kernels/mc_diffmahpop_kernels_sats.py deleted file mode 100644 index 1be1b15..0000000 --- a/diffmah/diffmahpop_kernels/mc_diffmahpop_kernels_sats.py +++ /dev/null @@ -1,135 +0,0 @@ -""" -""" - -from functools import partial - -from jax import jit as jjit -from jax import numpy as jnp -from jax import random as jran -from jax import vmap - -from ..diffmah_kernels import ( - DiffmahParams, - DiffmahUParams, - get_bounded_mah_params, - get_unbounded_mah_params, - mah_singlehalo, -) -from .covariance_kernels import _get_diffmahpop_cov -from .diffmahpop_params_censat import get_component_model_params -from .early_index_pop import _pred_early_index_kern -from .late_index_pop import _pred_late_index_kern -from .logm0_kernels.logm0_pop import _pred_logm0_kern -from .logtc_pop import _pred_logtc_kern -from .t_peak_kernels.tp_pdf_sats import mc_tpeak_singlesat - -N_TP_PER_HALO = 40 -T_OBS_FIT_MIN = 0.5 -NH_PER_M0BIN = 200 - - -@jjit -def mc_mean_diffmah_params(diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0): - model_params = get_component_model_params(diffmahpop_params) - ( - ftpt0_cens_params, - tp_pdf_cens_params, - tp_pdf_sats_params, - logm0_params, - logtc_params, - early_index_params, - late_index_params, - cov_params, - ) = model_params - - tpc_key, ran_key = jran.split(ran_key, 2) - - lgm_obs = lgm_obs - t_obs = t_obs - args = tp_pdf_sats_params, tpc_key, lgm_obs, t_obs - t_peak_sats = mc_tpeak_singlesat(*args) - - logm0_tp = _pred_logm0_kern(logm0_params, lgm_obs, t_obs, t_peak_sats) - - logtc_tp = _pred_logtc_kern(logtc_params, lgm_obs, t_obs, t_peak_sats) - - early_index_tp = _pred_early_index_kern( - early_index_params, lgm_obs, t_obs, t_peak_sats - ) - - late_index_tp = _pred_late_index_kern(late_index_params, lgm_obs) - - dmah_sats = DiffmahParams(logm0_tp, logtc_tp, early_index_tp, late_index_tp) - - return dmah_sats, t_peak_sats - - -@jjit -def mc_diffmah_params_singlesat(diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0): - dmah_sats, t_peak_sats = mc_mean_diffmah_params( - diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0 - ) - u_dmah_sats = get_unbounded_mah_params(dmah_sats) - - cov = _get_diffmahpop_cov(diffmahpop_params, lgm_obs) - - ran_key, tpt0_key, tp_key = jran.split(ran_key, 3) - - ran_diffmah_u_params_tp = jran.multivariate_normal( - tp_key, jnp.array(u_dmah_sats), cov, shape=() - ) - ran_diffmah_u_params = DiffmahUParams(*ran_diffmah_u_params_tp) - - mah_params = get_bounded_mah_params(ran_diffmah_u_params) - return mah_params, t_peak_sats - - -_A = (None, 0, 0, 0, None) -_mc_diffmah_params_vmap_kern = jjit(vmap(mc_diffmah_params_singlesat, in_axes=_A)) - - -@jjit -def mc_diffmah_params_satpop(diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0): - ran_keys = jran.split(ran_key, lgm_obs.size) - return _mc_diffmah_params_vmap_kern( - diffmahpop_params, lgm_obs, t_obs, ran_keys, lgt0 - ) - - -@jjit -def _mc_diffmah_singlesat(diffmahpop_params, tarr, lgm_obs, t_obs, ran_key, lgt0): - _res = mc_diffmah_params_singlesat(diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0) - mah_params, t_peak_sats = _res - dmhdt_sats, log_mah_sats = mah_singlehalo(mah_params, tarr, t_peak_sats, lgt0) - return mah_params, t_peak_sats, dmhdt_sats, log_mah_sats - - -_V = (None, None, 0, 0, 0, None) -_mc_diffmah_singlesat_vmap_kern = jjit(vmap(_mc_diffmah_singlesat, in_axes=_V)) - - -@partial(jjit, static_argnames=["n_mc"]) -def _mc_diffmah_halo_sample( - diffmahpop_params, tarr, lgm_obs, t_obs, ran_key, lgt0, n_mc=NH_PER_M0BIN -): - zz = jnp.zeros(n_mc) - ran_keys = jran.split(ran_key, n_mc) - return _mc_diffmah_singlesat_vmap_kern( - diffmahpop_params, tarr, lgm_obs + zz, t_obs + zz, ran_keys, lgt0 - ) - - -@jjit -def predict_mah_moments_singlebin( - diffmahpop_params, tarr, lgm_obs, t_obs, ran_key, lgt0 -): - _res = _mc_diffmah_halo_sample( - diffmahpop_params, tarr, lgm_obs, t_obs, ran_key, lgt0 - ) - mah_params, t_peak_sats, dmhdt, log_mah = _res - - mean_log_mah = jnp.mean(log_mah, axis=0) - std_log_mah = jnp.std(log_mah, axis=0) - frac_peaked = jnp.mean(dmhdt == 0, axis=0) - - return mean_log_mah, std_log_mah, frac_peaked diff --git a/diffmah/diffmahpop_kernels/mean_param_fitting_kernels.py b/diffmah/diffmahpop_kernels/mean_param_fitting_kernels.py deleted file mode 100644 index 750a333..0000000 --- a/diffmah/diffmahpop_kernels/mean_param_fitting_kernels.py +++ /dev/null @@ -1,125 +0,0 @@ -""" -""" - -from jax import jit as jjit -from jax import numpy as jnp -from jax import random as jran -from jax import value_and_grad, vmap - -from ..diffmah_kernels import DiffmahParams, mah_halopop, mah_singlehalo -from .diffmahpop_params import ( - DEFAULT_DIFFMAHPOP_U_PARAMS, - get_diffmahpop_params_from_u_params, -) -from .mc_diffmahpop_kernels import mc_mean_diffmah_params - -N_TP_PER_HALO = 40 -T_OBS_FIT_MIN = 0.5 - - -@jjit -def mc_tp_avg_mah_singlecen(diffmahpop_params, tarr, lgm_obs, t_obs, ran_key, lgt0): - dmah_tpt0, dmah_tp, t_peak, ftpt0, __ = mc_mean_diffmah_params( - diffmahpop_params, lgm_obs, t_obs, ran_key, lgt0 - ) - ZZ = jnp.zeros_like(t_peak) - tpt0 = ZZ + 10**lgt0 - __, log_mah_tpt0 = mah_halopop(dmah_tpt0, tarr, tpt0, lgt0) - __, log_mah_tp = mah_halopop(dmah_tp, tarr, t_peak, lgt0) - - avg_log_mah_tpt0 = jnp.mean(log_mah_tpt0, axis=0) - avg_log_mah_tp = jnp.mean(log_mah_tp, axis=0) - avg_log_mah = ftpt0 * avg_log_mah_tpt0 + (1 - ftpt0) * avg_log_mah_tp - return avg_log_mah - - -@jjit -def _mse(x, y): - d = y - x - return jnp.mean(d * d) - - -@jjit -def _loss_scalar_kern( - diffmahpop_params, tarr, lgm_obs, t_obs, ran_key, lgt0, avg_log_mah_target -): - avg_log_mah_pred = mc_tp_avg_mah_singlecen( - diffmahpop_params, tarr, lgm_obs, t_obs, ran_key, lgt0 - ) - loss = _mse(avg_log_mah_pred, avg_log_mah_target) - return loss - - -_A = (None, 0, 0, 0, 0, None, 0) -_loss_vmap_kern = jjit(vmap(_loss_scalar_kern, in_axes=_A)) - - -@jjit -def multiloss_vmap( - diffmahpop_params, tarr, lgm_obs, t_obs, ran_key, lgt0, avg_log_mah_target -): - losses = _loss_vmap_kern( - diffmahpop_params, tarr, lgm_obs, t_obs, ran_key, lgt0, avg_log_mah_target - ) - return jnp.sum(losses) - - -multiloss_and_grads_vmap = jjit(value_and_grad(multiloss_vmap)) - - -@jjit -def _loss_scalar_kern_subset_u_params( - diffmahpop_subset_u_params, tarr, lgm_obs, t_obs, ran_key, lgt0, avg_log_mah_target -): - diffmahpop_u_params = DEFAULT_DIFFMAHPOP_U_PARAMS._replace( - **diffmahpop_subset_u_params._asdict() - ) - diffmahpop_params = get_diffmahpop_params_from_u_params(diffmahpop_u_params) - args = diffmahpop_params, tarr, lgm_obs, t_obs, ran_key, lgt0, avg_log_mah_target - return _loss_scalar_kern(*args) - - -_A = (None, 0, 0, 0, 0, None, 0) -_loss_vmap_kern_subset_u_params = jjit( - vmap(_loss_scalar_kern_subset_u_params, in_axes=_A) -) - - -@jjit -def multiloss_vmap_subset_u_params(diffmahpop_subset_u_params, loss_data): - tarr, lgm_obs, t_obs, ran_key, lgt0, avg_log_mah_target = loss_data - losses = _loss_vmap_kern_subset_u_params( - diffmahpop_subset_u_params, - tarr, - lgm_obs, - t_obs, - ran_key, - lgt0, - avg_log_mah_target, - ) - return jnp.sum(losses) - - -multiloss_and_grads_vmap_subset_u_params = jjit( - value_and_grad(multiloss_vmap_subset_u_params) -) - - -def get_loss_data_singlehalo(mah_data, ih, lgt0, nt=50): - mah_params_ih = DiffmahParams( - *[mah_data[key][ih] for key in ("logm0", "logtc", "early_index", "late_index")] - ) - t_obs = mah_data["t_obs"][ih] - t_target = jnp.linspace(T_OBS_FIT_MIN, t_obs, nt) - args = mah_params_ih, t_target, mah_data["t_peak"][ih], lgt0 - avg_log_mah_target = mah_singlehalo(*args)[1] - ran_key = jran.key(ih) - loss_data = ( - t_target, - mah_data["logmp_at_z"][ih], - mah_data["t_obs"][ih], - ran_key, - lgt0, - avg_log_mah_target, - ) - return loss_data diff --git a/diffmah/diffmahpop_kernels/monocens_fithelp.py b/diffmah/diffmahpop_kernels/monocens_fithelp.py deleted file mode 100644 index 9759f76..0000000 --- a/diffmah/diffmahpop_kernels/monocens_fithelp.py +++ /dev/null @@ -1,98 +0,0 @@ -""" -""" - -from jax import jit as jjit -from jax import numpy as jnp -from jax import random as jran -from jax import value_and_grad, vmap - -from . import mc_diffmahpop_kernels_monocens as mcdk - -T_OBS_FIT_MIN = 0.5 - - -@jjit -def _mse(x, y): - d = y - x - return jnp.mean(d * d) - - -@jjit -def _loss_mah_moments_singlebin( - diffmahpop_params, - tarr, - lgm_obs, - t_obs, - ran_key, - lgt0, - target_mean_log_mah, - target_std_log_mah, - target_frac_peaked, -): - _preds = mcdk.predict_mah_moments_singlebin( - diffmahpop_params, tarr, lgm_obs, t_obs, ran_key, lgt0 - ) - mean_log_mah, std_log_mah, frac_peaked = _preds - loss = _mse(mean_log_mah, target_mean_log_mah) - loss = loss + _mse(std_log_mah, target_std_log_mah) - # loss = loss + _mse(frac_peaked, target_frac_peaked) - return loss - - -_U = (None, 0, 0, 0, 0, None, 0, 0, 0) -_loss_mah_moments_multibin_vmap = jjit(vmap(_loss_mah_moments_singlebin, in_axes=_U)) - - -@jjit -def _loss_mah_moments_multibin_kern( - diffmahpop_params, - tarr_matrix, - lgm_obs_arr, - t_obs_arr, - ran_key, - lgt0, - target_mean_log_mahs, - target_std_log_mahs, - target_frac_peaked, -): - ran_keys = jran.split(ran_key, tarr_matrix.shape[0]) - return _loss_mah_moments_multibin_vmap( - diffmahpop_params, - tarr_matrix, - lgm_obs_arr, - t_obs_arr, - ran_keys, - lgt0, - target_mean_log_mahs, - target_std_log_mahs, - target_frac_peaked, - ) - - -@jjit -def loss_mah_moments_multibin( - diffmahpop_params, - tarr_matrix, - lgm_obs_arr, - t_obs_arr, - ran_key, - lgt0, - target_mean_log_mahs, - target_std_log_mahs, - target_frac_peaked, -): - losses = _loss_mah_moments_multibin_kern( - diffmahpop_params, - tarr_matrix, - lgm_obs_arr, - t_obs_arr, - ran_key, - lgt0, - target_mean_log_mahs, - target_std_log_mahs, - target_frac_peaked, - ) - return jnp.mean(losses) - - -loss_and_grads_mah_moments_multibin = value_and_grad(loss_mah_moments_multibin) diff --git a/diffmah/diffmahpop_kernels/monocens_fixed_tpeak_fithelp.py b/diffmah/diffmahpop_kernels/monocens_fixed_tpeak_fithelp.py deleted file mode 100644 index 4b33d5c..0000000 --- a/diffmah/diffmahpop_kernels/monocens_fixed_tpeak_fithelp.py +++ /dev/null @@ -1,140 +0,0 @@ -""" -""" - -from collections import namedtuple - -from jax import jit as jjit -from jax import numpy as jnp -from jax import random as jran -from jax import value_and_grad, vmap - -from . import mc_diffmahpop_kernels_monocens_fixed_tpeak as mcdk -from .diffmahpop_params_monocensat import ( - DEFAULT_DIFFMAHPOP_U_PARAMS, - get_diffmahpop_params_from_u_params, -) - -T_OBS_FIT_MIN = 0.5 - - -def get_varied_u_params(): - fixed_u_pnames = ( - "u_cen_tp_x0_ylo", - "u_cen_tp_x0_yhi", - "u_utp_loc_lgm_ylo_t0", - "u_utp_loc_lgm_ylo_early", - "u_utp_loc_lgm_ylo_late", - "u_utp_loc_lgm_x0", - "u_utp_scale_lgm_ylo_t0", - "u_utp_scale_lgm_ylo_early", - "u_utp_scale_lgm_ylo_late", - ) - u_pdict = dict() - gen = zip(DEFAULT_DIFFMAHPOP_U_PARAMS._fields, DEFAULT_DIFFMAHPOP_U_PARAMS) - for key, val in gen: - if key not in fixed_u_pnames: - u_pdict[key] = val - VariedUParams = namedtuple("VariedUParams", u_pdict.keys()) - varied_u_params = VariedUParams(**u_pdict) - return varied_u_params - - -@jjit -def _mse(x, y): - d = y - x - return jnp.mean(d * d) - - -@jjit -def _loss_mah_moments_singlebin_u_params( - varied_u_params, - tarr, - lgm_obs, - t_obs, - t_peak_sample, - ran_key, - lgt0, - target_mean_log_mah, - target_std_log_mah, - target_frac_peaked, -): - u_params = DEFAULT_DIFFMAHPOP_U_PARAMS._replace(**varied_u_params._asdict()) - diffmahpop_params = get_diffmahpop_params_from_u_params(u_params) - _preds = mcdk.predict_mah_moments_singlebin( - diffmahpop_params, tarr, lgm_obs, t_obs, t_peak_sample, ran_key, lgt0 - ) - mean_log_mah, std_log_mah, frac_peaked = _preds - loss = _mse(mean_log_mah, target_mean_log_mah) - loss = loss + _mse(std_log_mah, target_std_log_mah) - # loss = loss + _mse(frac_peaked, target_frac_peaked) - return loss - - -_U = (None, 0, 0, 0, 0, 0, None, 0, 0, 0) -_loss_mah_moments_multibin_vmap = jjit( - vmap(_loss_mah_moments_singlebin_u_params, in_axes=_U) -) - - -@jjit -def _loss_mah_moments_multibin_kern( - diffmahpop_params, - tarr_matrix, - lgm_obs_arr, - t_obs_arr, - t_peak_arr, - ran_key, - lgt0, - target_mean_log_mahs, - target_std_log_mahs, - target_frac_peaked, -): - ran_keys = jran.split(ran_key, tarr_matrix.shape[0]) - return _loss_mah_moments_multibin_vmap( - diffmahpop_params, - tarr_matrix, - lgm_obs_arr, - t_obs_arr, - t_peak_arr, - ran_keys, - lgt0, - target_mean_log_mahs, - target_std_log_mahs, - target_frac_peaked, - ) - - -@jjit -def loss_mah_moments_multibin( - diffmahpop_params, - tarr_matrix, - lgm_obs_arr, - t_obs_arr, - t_peak_arr, - ran_key, - lgt0, - target_mean_log_mahs, - target_std_log_mahs, - target_frac_peaked, -): - losses = _loss_mah_moments_multibin_kern( - diffmahpop_params, - tarr_matrix, - lgm_obs_arr, - t_obs_arr, - t_peak_arr, - ran_key, - lgt0, - target_mean_log_mahs, - target_std_log_mahs, - target_frac_peaked, - ) - return jnp.mean(losses) - - -_loss_and_grads_mah_moments_multibin = value_and_grad(loss_mah_moments_multibin) - - -@jjit -def loss_and_grads_mah_moments_multibin(params, loss_data): - return _loss_and_grads_mah_moments_multibin(params, *loss_data) diff --git a/diffmah/diffmahpop_kernels/monocensat_fithelp.py b/diffmah/diffmahpop_kernels/monocensat_fithelp.py deleted file mode 100644 index bbbb72c..0000000 --- a/diffmah/diffmahpop_kernels/monocensat_fithelp.py +++ /dev/null @@ -1,78 +0,0 @@ -""" -""" - -from jax import jit as jjit -from jax import random as jran -from jax import value_and_grad - -from . import monocens_fithelp, monosats_fithelp -from .diffmahpop_params_monocensat import ( - DEFAULT_DIFFMAHPOP_PARAMS, - DEFAULT_DIFFMAHPOP_U_PARAMS, - get_diffmahpop_params_from_u_params, -) - - -@jjit -def loss_mah_moments_multibin_censat( - varied_diffmahpop_params, - tarr_matrix_cens, - lgm_obs_arr_cens, - t_obs_arr_cens, - tarr_matrix_sats, - lgm_obs_arr_sats, - t_obs_arr_sats, - ran_key, - lgt0, - target_mean_log_mahs_cens, - target_std_log_mahs_cens, - target_frac_peaked_cens, - target_mean_log_mahs_sats, - target_std_log_mahs_sats, - target_frac_peaked_sats, -): - diffmahpop_params = DEFAULT_DIFFMAHPOP_PARAMS._replace( - **varied_diffmahpop_params._asdict() - ) - ran_key_cens, ran_key_sats = jran.split(ran_key, 2) - loss_cens = monocens_fithelp.loss_mah_moments_multibin( - diffmahpop_params, - tarr_matrix_cens, - lgm_obs_arr_cens, - t_obs_arr_cens, - ran_key_cens, - lgt0, - target_mean_log_mahs_cens, - target_std_log_mahs_cens, - target_frac_peaked_cens, - ) - - loss_sats = monosats_fithelp.loss_mah_moments_multibin( - diffmahpop_params, - tarr_matrix_sats, - lgm_obs_arr_sats, - t_obs_arr_sats, - ran_key_sats, - lgt0, - target_mean_log_mahs_sats, - target_std_log_mahs_sats, - target_frac_peaked_sats, - ) - return loss_cens + loss_sats - - -loss_and_grads_mah_moments_multibin_censat = jjit( - value_and_grad(loss_mah_moments_multibin_censat) -) - - -@jjit -def loss_mah_moments_multibin_censat_u_params(u_params, loss_data): - u_params = DEFAULT_DIFFMAHPOP_U_PARAMS._replace(**u_params._asdict()) - params = get_diffmahpop_params_from_u_params(u_params) - return loss_mah_moments_multibin_censat(params, *loss_data) - - -loss_and_grads_mah_moments_multibin_censat_u_params = jjit( - value_and_grad(loss_mah_moments_multibin_censat_u_params) -) diff --git a/diffmah/diffmahpop_kernels/monosats_fithelp.py b/diffmah/diffmahpop_kernels/monosats_fithelp.py deleted file mode 100644 index 0a6ba1f..0000000 --- a/diffmah/diffmahpop_kernels/monosats_fithelp.py +++ /dev/null @@ -1,98 +0,0 @@ -""" -""" - -from jax import jit as jjit -from jax import numpy as jnp -from jax import random as jran -from jax import value_and_grad, vmap - -from . import mc_diffmahpop_kernels_monosats as mcs - -T_OBS_FIT_MIN = 0.5 - - -@jjit -def _mse(x, y): - d = y - x - return jnp.mean(d * d) - - -@jjit -def _loss_mah_moments_singlebin( - diffmahpop_params, - tarr, - lgm_obs, - t_obs, - ran_key, - lgt0, - target_mean_log_mah, - target_std_log_mah, - target_frac_peaked, -): - _preds = mcs.predict_mah_moments_singlebin( - diffmahpop_params, tarr, lgm_obs, t_obs, ran_key, lgt0 - ) - mean_log_mah, std_log_mah, frac_peaked = _preds - loss = _mse(mean_log_mah, target_mean_log_mah) - loss = loss + _mse(std_log_mah, target_std_log_mah) - # loss = loss + _mse(frac_peaked, target_frac_peaked) - return loss - - -_U = (None, 0, 0, 0, 0, None, 0, 0, 0) -_loss_mah_moments_multibin_vmap = jjit(vmap(_loss_mah_moments_singlebin, in_axes=_U)) - - -@jjit -def _loss_mah_moments_multibin_kern( - diffmahpop_params, - tarr_matrix, - lgm_obs_arr, - t_obs_arr, - ran_key, - lgt0, - target_mean_log_mahs, - target_std_log_mahs, - target_frac_peaked, -): - ran_keys = jran.split(ran_key, tarr_matrix.shape[0]) - return _loss_mah_moments_multibin_vmap( - diffmahpop_params, - tarr_matrix, - lgm_obs_arr, - t_obs_arr, - ran_keys, - lgt0, - target_mean_log_mahs, - target_std_log_mahs, - target_frac_peaked, - ) - - -@jjit -def loss_mah_moments_multibin( - diffmahpop_params, - tarr_matrix, - lgm_obs_arr, - t_obs_arr, - ran_key, - lgt0, - target_mean_log_mahs, - target_std_log_mahs, - target_frac_peaked, -): - losses = _loss_mah_moments_multibin_kern( - diffmahpop_params, - tarr_matrix, - lgm_obs_arr, - t_obs_arr, - ran_key, - lgt0, - target_mean_log_mahs, - target_std_log_mahs, - target_frac_peaked, - ) - return jnp.mean(losses) - - -loss_and_grads_mah_moments_multibin = value_and_grad(loss_mah_moments_multibin) diff --git a/diffmah/diffmahpop_kernels/t_peak_kernels/tests/test_tp_pdf_cens.py b/diffmah/diffmahpop_kernels/t_peak_kernels/tests/test_tp_pdf_cens.py deleted file mode 100644 index 4f2fef9..0000000 --- a/diffmah/diffmahpop_kernels/t_peak_kernels/tests/test_tp_pdf_cens.py +++ /dev/null @@ -1,96 +0,0 @@ -""" -""" - -import numpy as np -from jax import random as jran - -from .. import tp_pdf_cens as tpc - -TOL = 1e-3 - - -def test_t_peak_cens_fitter(): - t_0 = 13.8 - x_target = np.linspace(1, t_0, 100) - pdf_target = np.ones_like(x_target) - args = (x_target, pdf_target, t_0) - _res = tpc.t_peak_cens_fitter(*args) - p_best, loss_best, fit_terminates, __ = _res - assert np.all(np.isfinite(p_best)) - assert loss_best > 0 - assert fit_terminates - - -def test_param_u_param_names_propagate_properly(): - gen = zip( - tpc.DEFAULT_TPCENS_U_PARAMS._fields, - tpc.DEFAULT_TPCENS_PARAMS._fields, - ) - for u_key, key in gen: - assert u_key[:2] == "u_" - assert u_key[2:] == key - - inferred_default_params = tpc.get_bounded_tp_cens_params( - tpc.DEFAULT_TPCENS_U_PARAMS - ) - assert set(inferred_default_params._fields) == set( - tpc.DEFAULT_TPCENS_PARAMS._fields - ) - - inferred_default_u_params = tpc.get_unbounded_tp_cens_params( - tpc.DEFAULT_TPCENS_PARAMS - ) - assert set(inferred_default_u_params._fields) == set( - tpc.DEFAULT_TPCENS_U_PARAMS._fields - ) - - -def test_get_bounded_params_fails_when_passing_params(): - try: - tpc.get_bounded_tp_pdf_params(tpc.DEFAULT_TPCENS_PARAMS) - raise NameError("get_bounded_tp_pdf_params should not accept params") - except AttributeError: - pass - - -def test_get_unbounded_params_fails_when_passing_u_params(): - try: - tpc.get_unbounded_tp_pdf_params(tpc.DEFAULT_TPCENS_U_PARAMS) - raise NameError("get_unbounded_tp_pdf_params should not accept u_params") - except AttributeError: - pass - - -def test_param_u_param_inversion(): - assert np.allclose( - tpc.DEFAULT_TPCENS_PARAMS, - tpc.get_bounded_tp_cens_params(tpc.DEFAULT_TPCENS_U_PARAMS), - rtol=TOL, - ) - - inferred_default_params = tpc.get_bounded_tp_cens_params( - tpc.get_unbounded_tp_cens_params(tpc.DEFAULT_TPCENS_PARAMS) - ) - assert np.allclose(tpc.DEFAULT_TPCENS_PARAMS, inferred_default_params, rtol=TOL) - - -def test_default_params_are_in_bounds(): - for key in tpc.DEFAULT_TPCENS_PARAMS._fields: - val = getattr(tpc.DEFAULT_TPCENS_PARAMS, key) - bound = getattr(tpc.TPCENS_PBOUNDS, key) - assert bound[0] < val < bound[1] - - -def test_mc_tpeak_cens(): - t_0 = 13.0 - n_gals = int(1e4) - ran_key = jran.key(0) - ran_key, m_key, t_key = jran.split(ran_key, 3) - lgm_obs = jran.uniform(m_key, minval=10.0, maxval=15.0, shape=(n_gals,)) - t_obs = jran.uniform(t_key, minval=2.0, maxval=t_0, shape=(n_gals,)) - args = tpc.DEFAULT_TPCENS_PARAMS, ran_key, lgm_obs, t_obs, t_0 - t_peak_mc_sample = tpc.mc_tpeak_cens(*args) - assert t_peak_mc_sample.shape == (n_gals,) - assert np.all(np.isfinite(t_peak_mc_sample)) - assert np.all(t_peak_mc_sample > 0) - assert np.all(t_peak_mc_sample <= t_0) diff --git a/diffmah/diffmahpop_kernels/t_peak_kernels/tests/test_tp_pdf_monocens.py b/diffmah/diffmahpop_kernels/t_peak_kernels/tests/test_tp_pdf_monocens.py deleted file mode 100644 index bed70b5..0000000 --- a/diffmah/diffmahpop_kernels/t_peak_kernels/tests/test_tp_pdf_monocens.py +++ /dev/null @@ -1,96 +0,0 @@ -""" -""" - -import numpy as np -from jax import random as jran - -from .. import tp_pdf_monocens as tpc - -TOL = 1e-3 - - -def test_param_u_param_names_propagate_properly(): - gen = zip( - tpc.DEFAULT_TPCENS_U_PARAMS._fields, - tpc.DEFAULT_TPCENS_PARAMS._fields, - ) - for u_key, key in gen: - assert u_key[:2] == "u_" - assert u_key[2:] == key - - inferred_default_params = tpc.get_bounded_tp_cens_params( - tpc.DEFAULT_TPCENS_U_PARAMS - ) - assert set(inferred_default_params._fields) == set( - tpc.DEFAULT_TPCENS_PARAMS._fields - ) - - inferred_default_u_params = tpc.get_unbounded_tp_cens_params( - tpc.DEFAULT_TPCENS_PARAMS - ) - assert set(inferred_default_u_params._fields) == set( - tpc.DEFAULT_TPCENS_U_PARAMS._fields - ) - - -def test_get_bounded_params_fails_when_passing_params(): - try: - tpc.get_bounded_tp_pdf_params(tpc.DEFAULT_TPCENS_PARAMS) - raise NameError("get_bounded_tp_pdf_params should not accept params") - except AttributeError: - pass - - -def test_get_unbounded_params_fails_when_passing_u_params(): - try: - tpc.get_unbounded_tp_pdf_params(tpc.DEFAULT_TPCENS_U_PARAMS) - raise NameError("get_unbounded_tp_pdf_params should not accept u_params") - except AttributeError: - pass - - -def test_param_u_param_inversion(): - assert np.allclose( - tpc.DEFAULT_TPCENS_PARAMS, - tpc.get_bounded_tp_cens_params(tpc.DEFAULT_TPCENS_U_PARAMS), - rtol=TOL, - ) - - inferred_default_params = tpc.get_bounded_tp_cens_params( - tpc.get_unbounded_tp_cens_params(tpc.DEFAULT_TPCENS_PARAMS) - ) - assert np.allclose(tpc.DEFAULT_TPCENS_PARAMS, inferred_default_params, rtol=TOL) - - -def test_default_params_are_in_bounds(): - for key in tpc.DEFAULT_TPCENS_PARAMS._fields: - val = getattr(tpc.DEFAULT_TPCENS_PARAMS, key) - bound = getattr(tpc.TPCENS_PBOUNDS, key) - assert bound[0] < val < bound[1] - - -def test_mc_tpeak_singlecen(): - t_0 = 13.0 - ran_key = jran.key(0) - ran_key, m_key = jran.split(ran_key, 2) - lgm_obs = jran.uniform(m_key, minval=10.0, maxval=15.0, shape=()) - args = tpc.DEFAULT_TPCENS_PARAMS, lgm_obs, ran_key, t_0 - t_peak_mc_sample = tpc.mc_tpeak_singlecen(*args) - assert t_peak_mc_sample.shape == () - assert np.all(np.isfinite(t_peak_mc_sample)) - assert np.all(t_peak_mc_sample > 0) - assert np.all(t_peak_mc_sample <= t_0) - - -def test_mc_t_peak_cenpop(): - t_0 = 13.0 - ran_key = jran.key(0) - ran_key, m_key = jran.split(ran_key, 2) - n_gals = int(1e4) - lgm_obs = jran.uniform(m_key, minval=10.0, maxval=15.0, shape=(n_gals,)) - args = tpc.DEFAULT_TPCENS_PARAMS, lgm_obs, ran_key, t_0 - t_peak_mc_sample = tpc.mc_t_peak_cenpop(*args) - assert t_peak_mc_sample.shape == (n_gals,) - assert np.all(np.isfinite(t_peak_mc_sample)) - assert np.all(t_peak_mc_sample <= t_0) - assert np.all(t_peak_mc_sample > 0), t_peak_mc_sample.min() diff --git a/diffmah/diffmahpop_kernels/t_peak_kernels/tp_pdf_monocens.py b/diffmah/diffmahpop_kernels/t_peak_kernels/tp_pdf_monocens.py deleted file mode 100644 index 003c66b..0000000 --- a/diffmah/diffmahpop_kernels/t_peak_kernels/tp_pdf_monocens.py +++ /dev/null @@ -1,169 +0,0 @@ -"""Sigmoid model of the CDF P_cen(t 0) - assert np.all(t_peak <= t_0) - assert np.any(t_peak < t_0) - - assert ftpt0.shape == () - assert np.all(ftpt0 >= 0) - assert np.all(ftpt0 <= 1) - assert np.any(ftpt0 > 0) - assert np.any(ftpt0 < 1) - - for p, bound in zip(dmah_tpt0, MAH_PBOUNDS): - assert np.all(bound[0] < p) - assert np.all(p < bound[1]) - for p, bound in zip(dmah_tp, MAH_PBOUNDS): - assert np.all(bound[0] < p) - assert np.all(p < bound[1]) - - -def test_mc_mean_diffmah_params(): - t_obs = 10.0 - t_0 = 13.8 - ran_key = jran.key(0) - for lgm_obs in np.linspace(10, 16, 20): - args = DEFAULT_DIFFMAHPOP_PARAMS, lgm_obs, t_obs, ran_key, np.log10(t_0) - _res = mcdpk.mc_mean_diffmah_params(*args) - ran_diffmah_params_tpt0, ran_diffmah_params_tp, t_peak, ftpt0, mc_tpt0 = _res - for _x in _res: - assert np.all(np.isfinite(_x)) - - -def test_mc_diffmah_params_singlecen(): - ran_key = jran.key(0) - t_0 = 13.0 - lgt0 = np.log10(t_0) - t_obs = 10.0 - lgmarr = np.linspace(10, 15, 20) - for lgm_obs in lgmarr: - args = (DEFAULT_DIFFMAHPOP_PARAMS, lgm_obs, t_obs, ran_key, lgt0) - _res = mcdpk.mc_diffmah_params_singlecen(*args) - mah_params_tpt0, mah_params_tp, t_peak, ftpt0, mc_tpt0 = _res - assert np.all(np.isfinite(mah_params_tpt0.logtc)) - assert np.all(np.isfinite(mah_params_tp.logtc)) - - -def test_predict_mah_moments_singlebin(): - ran_key = jran.key(0) - t_0 = 13.0 - lgt0 = np.log10(t_0) - t_obs = 10.0 - tarr = np.linspace(0.1, t_obs, 100) - lgmarr = np.linspace(10, 15, 20) - for lgm_obs in lgmarr: - args = (DEFAULT_DIFFMAHPOP_PARAMS, tarr, lgm_obs, t_obs, ran_key, lgt0) - mean_log_mah, std_log_mah = mcdpk.predict_mah_moments_singlebin(*args) - assert np.all(np.isfinite(mean_log_mah)) - assert np.all(np.isfinite(std_log_mah)) - - -def test_mc_diffmah_halo_sample(): - ran_key = jran.key(0) - t_0 = 13.0 - lgt0 = np.log10(t_0) - t_obs = 10.0 - tarr = np.linspace(0.1, t_obs, 100) - lgmarr = np.linspace(10, 16, 20) - for lgm_obs in lgmarr: - args = (DEFAULT_DIFFMAHPOP_PARAMS, tarr, lgm_obs, t_obs, ran_key, lgt0) - _res = mcdpk._mc_diffmah_halo_sample(*args) - ( - mah_params_tpt0, - mah_params_tp, - t_peak, - ftpt0, - mc_tpt0, - dmhdt_tpt0, - log_mah_tpt0, - dmhdt_tp, - log_mah_tp, - ) = _res - assert np.all(np.isfinite(mah_params_tpt0)) - assert np.all(np.isfinite(mah_params_tp)) - - assert np.all(np.isfinite(t_peak)) - assert np.all(t_peak > 0.0) - assert np.all(t_peak <= t_0) - - assert np.all(np.isfinite(ftpt0)) - assert np.all(ftpt0 > 0.0) - assert np.all(ftpt0 < 1.0) - - assert np.all(np.isfinite(log_mah_tpt0)) - assert np.all(np.isfinite(log_mah_tp)) diff --git a/diffmah/diffmahpop_kernels/tests/test_mc_diffmahpop_cens.py b/diffmah/diffmahpop_kernels/tests/test_mc_diffmahpop_cens.py deleted file mode 100644 index 56763cf..0000000 --- a/diffmah/diffmahpop_kernels/tests/test_mc_diffmahpop_cens.py +++ /dev/null @@ -1,113 +0,0 @@ -""" -""" - -import numpy as np -from jax import random as jran - -from ...diffmah_kernels import MAH_PBOUNDS -from .. import mc_diffmahpop_kernels_cens as mcdpk -from ..diffmahpop_params_censat import DEFAULT_DIFFMAHPOP_PARAMS - - -def test_mc_mean_diffmah_params_are_always_in_bounds(): - t_obs = 10.0 - t_0 = 13.8 - ran_key = jran.key(0) - lgmarr = np.linspace(10, 16, 20) - for lgm_obs in lgmarr: - dmah_tpt0, dmah_tp, t_peak, ftpt0, mc_tpt0 = mcdpk.mc_mean_diffmah_params( - DEFAULT_DIFFMAHPOP_PARAMS, lgm_obs, t_obs, ran_key, np.log10(t_0) - ) - assert np.all(t_peak > 0) - assert np.all(t_peak <= t_0) - assert np.any(t_peak < t_0) - - assert ftpt0.shape == () - assert np.all(ftpt0 >= 0) - assert np.all(ftpt0 <= 1) - assert np.any(ftpt0 > 0) - assert np.any(ftpt0 < 1) - - for p, bound in zip(dmah_tpt0, MAH_PBOUNDS): - assert np.all(bound[0] < p) - assert np.all(p < bound[1]) - for p, bound in zip(dmah_tp, MAH_PBOUNDS): - assert np.all(bound[0] < p) - assert np.all(p < bound[1]) - - -def test_mc_mean_diffmah_params(): - t_obs = 10.0 - t_0 = 13.8 - ran_key = jran.key(0) - for lgm_obs in np.linspace(10, 16, 20): - args = DEFAULT_DIFFMAHPOP_PARAMS, lgm_obs, t_obs, ran_key, np.log10(t_0) - _res = mcdpk.mc_mean_diffmah_params(*args) - ran_diffmah_params_tpt0, ran_diffmah_params_tp, t_peak, ftpt0, mc_tpt0 = _res - for _x in _res: - assert np.all(np.isfinite(_x)) - - -def test_mc_diffmah_params_singlecen(): - ran_key = jran.key(0) - t_0 = 13.0 - lgt0 = np.log10(t_0) - t_obs = 10.0 - lgmarr = np.linspace(10, 15, 20) - for lgm_obs in lgmarr: - args = (DEFAULT_DIFFMAHPOP_PARAMS, lgm_obs, t_obs, ran_key, lgt0) - _res = mcdpk.mc_diffmah_params_singlecen(*args) - mah_params_tpt0, mah_params_tp, t_peak, ftpt0, mc_tpt0 = _res - assert np.all(np.isfinite(mah_params_tpt0.logtc)) - assert np.all(np.isfinite(mah_params_tp.logtc)) - - -def test_predict_mah_moments_singlebin(): - ran_key = jran.key(0) - t_0 = 13.0 - lgt0 = np.log10(t_0) - t_obs = 10.0 - tarr = np.linspace(0.1, t_obs, 100) - lgmarr = np.linspace(10, 15, 20) - for lgm_obs in lgmarr: - args = (DEFAULT_DIFFMAHPOP_PARAMS, tarr, lgm_obs, t_obs, ran_key, lgt0) - mean_log_mah, std_log_mah, f_peaked = mcdpk.predict_mah_moments_singlebin(*args) - assert np.all(np.isfinite(mean_log_mah)) - assert np.all(np.isfinite(std_log_mah)) - assert np.all(np.isfinite(f_peaked)) - - -def test_mc_diffmah_halo_sample(): - ran_key = jran.key(0) - t_0 = 13.0 - lgt0 = np.log10(t_0) - t_obs = 10.0 - tarr = np.linspace(0.1, t_obs, 100) - lgmarr = np.linspace(10, 16, 20) - for lgm_obs in lgmarr: - args = (DEFAULT_DIFFMAHPOP_PARAMS, tarr, lgm_obs, t_obs, ran_key, lgt0) - _res = mcdpk._mc_diffmah_halo_sample(*args) - ( - mah_params_tpt0, - mah_params_tp, - t_peak, - ftpt0, - mc_tpt0, - dmhdt_tpt0, - log_mah_tpt0, - dmhdt_tp, - log_mah_tp, - ) = _res - assert np.all(np.isfinite(mah_params_tpt0)) - assert np.all(np.isfinite(mah_params_tp)) - - assert np.all(np.isfinite(t_peak)) - assert np.all(t_peak > 0.0) - assert np.all(t_peak <= t_0) - - assert np.all(np.isfinite(ftpt0)) - assert np.all(ftpt0 > 0.0) - assert np.all(ftpt0 < 1.0) - - assert np.all(np.isfinite(log_mah_tpt0)) - assert np.all(np.isfinite(log_mah_tp)) diff --git a/diffmah/diffmahpop_kernels/tests/test_mc_diffmahpop_censat.py b/diffmah/diffmahpop_kernels/tests/test_mc_diffmahpop_censat.py deleted file mode 100644 index 2f7e1d4..0000000 --- a/diffmah/diffmahpop_kernels/tests/test_mc_diffmahpop_censat.py +++ /dev/null @@ -1,137 +0,0 @@ -""" -""" - -import numpy as np -from jax import random as jran - -from ...diffmah_kernels import MAH_PBOUNDS -from .. import mc_diffmahpop_kernels_censat as mcdpk -from ..diffmahpop_params_censat import DEFAULT_DIFFMAHPOP_PARAMS - - -def test_mc_mean_diffmah_params_are_always_in_bounds(): - t_obs = 10.0 - t_0 = 13.8 - ran_key = jran.key(0) - lgmarr = np.linspace(10, 16, 20) - for lgm_obs in lgmarr: - _res = mcdpk.mc_mean_diffmah_params( - DEFAULT_DIFFMAHPOP_PARAMS, lgm_obs, t_obs, ran_key, np.log10(t_0) - ) - ( - dmah_tpt0_cens, - dmah_tp_cens, - t_peak_cens, - frac_tpt0_cens, - mc_tpt0_cens, - t_peak_sats, - dmah_sats, - ) = _res - assert np.all(t_peak_cens > 0) - assert np.all(t_peak_cens <= t_0) - assert np.any(t_peak_cens < t_0) - - assert np.all(t_peak_sats > 0) - assert np.all(t_peak_sats < t_obs) - - assert frac_tpt0_cens.shape == () - assert np.all(frac_tpt0_cens >= 0) - assert np.all(frac_tpt0_cens <= 1) - assert np.any(frac_tpt0_cens > 0) - assert np.any(frac_tpt0_cens < 1) - - for p, bound in zip(dmah_tpt0_cens, MAH_PBOUNDS): - assert np.all(bound[0] < p) - assert np.all(p < bound[1]) - for p, bound in zip(dmah_tp_cens, MAH_PBOUNDS): - assert np.all(bound[0] < p) - assert np.all(p < bound[1]) - for p, bound in zip(dmah_sats, MAH_PBOUNDS): - assert np.all(bound[0] < p) - assert np.all(p < bound[1]) - - -def test_mc_mean_diffmah_params(): - t_obs = 10.0 - t_0 = 13.8 - ran_key = jran.key(0) - for lgm_obs in np.linspace(10, 16, 20): - args = DEFAULT_DIFFMAHPOP_PARAMS, lgm_obs, t_obs, ran_key, np.log10(t_0) - _res = mcdpk.mc_mean_diffmah_params(*args) - for _x in _res: - assert np.all(np.isfinite(_x)) - - -def test_mc_diffmah_params_single_censat(): - ran_key = jran.key(0) - t_0 = 13.0 - lgt0 = np.log10(t_0) - t_obs = 10.0 - lgmarr = np.linspace(10, 15, 20) - for lgm_obs in lgmarr: - args = (DEFAULT_DIFFMAHPOP_PARAMS, lgm_obs, t_obs, ran_key, lgt0) - _res = mcdpk.mc_diffmah_params_single_censat(*args) - ( - mah_params_tpt0_cens, - mah_params_tp_cens, - t_peak_cens, - ftpt0_cens, - mc_tpt0_cens, - t_peak_sats, - mah_params_sats, - ) = _res - assert np.all(np.isfinite(mah_params_tpt0_cens.logtc)) - assert np.all(np.isfinite(mah_params_tp_cens.logtc)) - assert np.all(np.isfinite(mah_params_sats.logtc)) - - -def test_predict_mah_moments_singlebin_censat(): - ran_key = jran.key(0) - t_0 = 13.0 - lgt0 = np.log10(t_0) - t_obs = 10.0 - tarr = np.linspace(0.1, t_obs, 100) - lgmarr = np.linspace(10, 15, 20) - for lgm_obs in lgmarr: - args = (DEFAULT_DIFFMAHPOP_PARAMS, tarr, lgm_obs, t_obs, ran_key, lgt0) - _res = mcdpk.predict_mah_moments_singlebin_censat(*args) - for _x in _res: - assert np.all(np.isfinite(_x)) - - -def test_mc_diffmah_halo_sample(): - ran_key = jran.key(0) - t_0 = 13.0 - lgt0 = np.log10(t_0) - t_obs = 10.0 - tarr = np.linspace(0.1, t_obs, 100) - lgmarr = np.linspace(10, 16, 20) - for lgm_obs in lgmarr: - args = (DEFAULT_DIFFMAHPOP_PARAMS, tarr, lgm_obs, t_obs, ran_key, lgt0) - _res = mcdpk._mc_diffmah_halo_sample_censat(*args) - ( - mah_params_tpt0_cens, - mah_params_tp_cens, - t_peak_cens, - frac_tpt0_cens, - mc_tpt0_cens, - dmhdt_tpt0_cens, - log_mah_tpt0_cens, - dmhdt_tp_cens, - log_mah_tp_cens, - dmhdt_sats, - log_mah_sats, - ) = _res - assert np.all(np.isfinite(mah_params_tpt0_cens)) - assert np.all(np.isfinite(mah_params_tp_cens)) - - assert np.all(np.isfinite(t_peak_cens)) - assert np.all(t_peak_cens > 0.0) - assert np.all(t_peak_cens <= t_0) - - assert np.all(np.isfinite(frac_tpt0_cens)) - assert np.all(frac_tpt0_cens > 0.0) - assert np.all(frac_tpt0_cens < 1.0) - - assert np.all(np.isfinite(log_mah_tpt0_cens)) - assert np.all(np.isfinite(log_mah_tp_cens)) diff --git a/diffmah/diffmahpop_kernels/tests/test_mc_diffmahpop_monocens.py b/diffmah/diffmahpop_kernels/tests/test_mc_diffmahpop_monocens.py deleted file mode 100644 index 1967684..0000000 --- a/diffmah/diffmahpop_kernels/tests/test_mc_diffmahpop_monocens.py +++ /dev/null @@ -1,91 +0,0 @@ -""" -""" - -import numpy as np -from jax import random as jran - -from ...diffmah_kernels import MAH_PBOUNDS -from .. import mc_diffmahpop_kernels_monocens as mcdpk -from ..diffmahpop_params_monocensat import DEFAULT_DIFFMAHPOP_PARAMS - -EPS = 1e-4 - - -def test_mc_mean_diffmah_params_are_always_in_bounds(): - t_obs = 10.0 - t_0 = 13.8 - ran_key = jran.key(0) - lgmarr = np.linspace(10, 16, 20) - for lgm_obs in lgmarr: - mah_params, t_peak = mcdpk.mc_mean_diffmah_params( - DEFAULT_DIFFMAHPOP_PARAMS, lgm_obs, t_obs, ran_key, np.log10(t_0) - ) - assert np.all(t_peak > 0) - assert np.all(t_peak <= t_0 + EPS) - - for p, bound in zip(mah_params, MAH_PBOUNDS): - assert np.all(bound[0] < p) - assert np.all(p < bound[1]) - for p, bound in zip(mah_params, MAH_PBOUNDS): - assert np.all(bound[0] < p) - assert np.all(p < bound[1]) - - -def test_mc_mean_diffmah_params(): - t_obs = 10.0 - t_0 = 13.8 - ran_key = jran.key(0) - for lgm_obs in np.linspace(10, 16, 20): - args = DEFAULT_DIFFMAHPOP_PARAMS, lgm_obs, t_obs, ran_key, np.log10(t_0) - _res = mcdpk.mc_mean_diffmah_params(*args) - for _x in _res: - assert np.all(np.isfinite(_x)) - - -def test_mc_diffmah_params_singlecen(): - ran_key = jran.key(0) - t_0 = 13.0 - lgt0 = np.log10(t_0) - t_obs = 10.0 - lgmarr = np.linspace(10, 15, 20) - for lgm_obs in lgmarr: - args = (DEFAULT_DIFFMAHPOP_PARAMS, lgm_obs, t_obs, ran_key, lgt0) - _res = mcdpk.mc_diffmah_params_singlecen(*args) - mah_params, t_peak = _res - assert np.all(np.isfinite(mah_params.logtc)) - - -def test_predict_mah_moments_singlebin(): - ran_key = jran.key(0) - t_0 = 13.0 - lgt0 = np.log10(t_0) - t_obs = 10.0 - tarr = np.linspace(0.1, t_obs, 100) - lgmarr = np.linspace(10, 15, 20) - for lgm_obs in lgmarr: - args = (DEFAULT_DIFFMAHPOP_PARAMS, tarr, lgm_obs, t_obs, ran_key, lgt0) - mean_log_mah, std_log_mah, f_peaked = mcdpk.predict_mah_moments_singlebin(*args) - assert np.all(np.isfinite(mean_log_mah)) - assert np.all(np.isfinite(std_log_mah)) - assert np.all(np.isfinite(f_peaked)) - - -def test_mc_diffmah_halo_sample(): - ran_key = jran.key(0) - t_0 = 13.0 - lgt0 = np.log10(t_0) - t_obs = 10.0 - tarr = np.linspace(0.1, t_obs, 100) - lgmarr = np.linspace(10, 16, 20) - for lgm_obs in lgmarr: - args = (DEFAULT_DIFFMAHPOP_PARAMS, tarr, lgm_obs, t_obs, ran_key, lgt0) - _res = mcdpk._mc_diffmah_halo_sample(*args) - (mah_params, t_peak, dmhdt, log_mah) = _res - assert np.all(np.isfinite(mah_params)) - - assert np.all(np.isfinite(t_peak)) - assert np.all(t_peak > 0.0) - assert np.all(t_peak <= t_0) - - assert np.all(np.isfinite(log_mah)) - assert np.all(np.isfinite(dmhdt)) diff --git a/diffmah/diffmahpop_kernels/tests/test_mc_diffmahpop_monocens_fixed_tpeak.py b/diffmah/diffmahpop_kernels/tests/test_mc_diffmahpop_monocens_fixed_tpeak.py deleted file mode 100644 index 3436832..0000000 --- a/diffmah/diffmahpop_kernels/tests/test_mc_diffmahpop_monocens_fixed_tpeak.py +++ /dev/null @@ -1,103 +0,0 @@ -""" -""" - -import numpy as np -from jax import random as jran - -from ...diffmah_kernels import MAH_PBOUNDS -from .. import mc_diffmahpop_kernels_monocens_fixed_tpeak as mcdpk -from ..diffmahpop_params_monocensat import DEFAULT_DIFFMAHPOP_PARAMS - -EPS = 1e-4 - - -def test_mc_mean_diffmah_params_are_always_in_bounds(): - t_obs = 10.0 - lgmarr = np.linspace(10, 16, 20) - t_peak = 10.0 - for lgm_obs in lgmarr: - mah_params = mcdpk.mc_mean_diffmah_params( - DEFAULT_DIFFMAHPOP_PARAMS, lgm_obs, t_obs, t_peak - ) - for p, bound in zip(mah_params, MAH_PBOUNDS): - assert np.all(bound[0] < p) - assert np.all(p < bound[1]) - for p, bound in zip(mah_params, MAH_PBOUNDS): - assert np.all(bound[0] < p) - assert np.all(p < bound[1]) - - -def test_mc_mean_diffmah_params(): - t_obs = 10.0 - t_peak = 8.0 - for lgm_obs in np.linspace(10, 16, 20): - args = DEFAULT_DIFFMAHPOP_PARAMS, lgm_obs, t_obs, t_peak - _res = mcdpk.mc_mean_diffmah_params(*args) - for _x in _res: - assert np.all(np.isfinite(_x)) - - -def test_mc_diffmah_params_singlecen(): - ran_key = jran.key(0) - t_obs = 10.0 - t_peak = 8.0 - lgmarr = np.linspace(10, 15, 20) - for lgm_obs in lgmarr: - args = (DEFAULT_DIFFMAHPOP_PARAMS, lgm_obs, t_obs, t_peak, ran_key) - mah_params = mcdpk.mc_diffmah_params_singlecen(*args) - assert np.all(np.isfinite(mah_params.logtc)) - - -def test_predict_mah_moments_singlebin(): - ran_key = jran.key(0) - t_0 = 13.0 - lgt0 = np.log10(t_0) - t_obs = 10.0 - tarr = np.linspace(0.1, t_obs, 100) - n_sample = 20 - lgmarr = np.linspace(10, 15, n_sample) - t_peak_sample = np.linspace(3, 10, n_sample) - for lgm_obs in lgmarr: - args = ( - DEFAULT_DIFFMAHPOP_PARAMS, - tarr, - lgm_obs, - t_obs, - t_peak_sample, - ran_key, - lgt0, - ) - mean_log_mah, std_log_mah, f_peaked = mcdpk.predict_mah_moments_singlebin(*args) - assert np.all(np.isfinite(mean_log_mah)) - assert np.all(np.isfinite(std_log_mah)) - assert np.all(np.isfinite(f_peaked)) - - -def test_mc_diffmah_halo_sample(): - ran_key = jran.key(0) - t_0 = 13.0 - lgt0 = np.log10(t_0) - t_obs = 10.0 - n_t = 100 - tarr = np.linspace(0.1, t_obs, n_t) - n_test = 20 - lgmarr = np.linspace(10, 15, n_test) - n_sample = 50 - t_peak_sample = np.linspace(3, 10, n_sample) - for lgm_obs in lgmarr: - args = ( - DEFAULT_DIFFMAHPOP_PARAMS, - tarr, - lgm_obs, - t_obs, - t_peak_sample, - ran_key, - lgt0, - ) - _res = mcdpk._mc_diffmah_halo_sample(*args) - (mah_params, dmhdt, log_mah) = _res - assert log_mah.shape == (n_sample, n_t) - assert np.all(np.isfinite(mah_params)) - - assert np.all(np.isfinite(log_mah)) - assert np.all(np.isfinite(dmhdt)) diff --git a/diffmah/diffmahpop_kernels/tests/test_mc_diffmahpop_monosats.py b/diffmah/diffmahpop_kernels/tests/test_mc_diffmahpop_monosats.py deleted file mode 100644 index 7bd1581..0000000 --- a/diffmah/diffmahpop_kernels/tests/test_mc_diffmahpop_monosats.py +++ /dev/null @@ -1,98 +0,0 @@ -""" -""" - -import numpy as np -from jax import random as jran - -from ...diffmah_kernels import MAH_PBOUNDS -from .. import mc_diffmahpop_kernels_monosats as mcdpk -from ..diffmahpop_params_monocensat import DEFAULT_DIFFMAHPOP_PARAMS - - -def test_mc_mean_diffmah_params_are_always_in_bounds(): - t_obs = 10.0 - ran_key = jran.key(0) - lgmarr = np.linspace(10, 16, 20) - for lgm_obs in lgmarr: - dmah_sats, t_peak_sats = mcdpk.mc_mean_diffmah_params( - DEFAULT_DIFFMAHPOP_PARAMS, lgm_obs, t_obs, ran_key - ) - assert np.all(t_peak_sats > 0) - assert np.all(t_peak_sats <= t_obs) - - for p, bound in zip(dmah_sats, MAH_PBOUNDS): - assert np.all(bound[0] < p) - assert np.all(p < bound[1]) - - -def test_mc_mean_diffmah_params(): - t_obs = 10.0 - ran_key = jran.key(0) - for lgm_obs in np.linspace(10, 16, 20): - args = DEFAULT_DIFFMAHPOP_PARAMS, lgm_obs, t_obs, ran_key - _res = mcdpk.mc_mean_diffmah_params(*args) - dmah_sats, t_peak_sats = _res - for _x in _res: - assert np.all(np.isfinite(_x)) - - -def test_mc_diffmah_params_singlesat(): - ran_key = jran.key(0) - t_obs = 10.0 - lgmarr = np.linspace(10, 15, 20) - for lgm_obs in lgmarr: - args = (DEFAULT_DIFFMAHPOP_PARAMS, lgm_obs, t_obs, ran_key) - _res = mcdpk.mc_diffmah_params_singlesat(*args) - mah_params, t_peak_sats = _res - assert np.all(np.isfinite(mah_params.logtc)) - assert np.all(np.isfinite(t_peak_sats)) - - -def test_predict_mah_moments_singlebin(): - ran_key = jran.key(0) - t_0 = 13.0 - lgt0 = np.log10(t_0) - t_obs = 10.0 - tarr = np.linspace(0.1, t_obs, 100) - lgmarr = np.linspace(10, 15, 20) - for lgm_obs in lgmarr: - args = (DEFAULT_DIFFMAHPOP_PARAMS, tarr, lgm_obs, t_obs, ran_key, lgt0) - mean_log_mah, std_log_mah, f_peaked = mcdpk.predict_mah_moments_singlebin(*args) - assert np.all(np.isfinite(mean_log_mah)) - assert np.all(np.isfinite(std_log_mah)) - assert np.all(np.isfinite(f_peaked)) - assert np.all(f_peaked >= 0) - assert np.all(f_peaked <= 1) - - -def test_mc_diffmah_halo_sample(): - ran_key = jran.key(0) - t_0 = 13.0 - lgt0 = np.log10(t_0) - t_obs = 10.0 - tarr = np.linspace(0.1, t_obs, 100) - lgmarr = np.linspace(10, 16, 20) - for lgm_obs in lgmarr: - args = (DEFAULT_DIFFMAHPOP_PARAMS, tarr, lgm_obs, t_obs, ran_key, lgt0) - _res = mcdpk._mc_diffmah_halo_sample(*args) - (mah_params, t_peak_sats, dmhdt_sats, log_mah_sats) = _res - assert np.all(np.isfinite(mah_params)) - - assert np.all(np.isfinite(t_peak_sats)) - assert np.all(t_peak_sats > 0.0) - assert np.all(t_peak_sats < t_obs) - - assert np.all(np.isfinite(log_mah_sats)) - assert np.all(np.isfinite(dmhdt_sats)) - - -def test_mc_diffmah_params_satpop(): - ran_key = jran.key(0) - lgm_obs, t_obs = 12.0, 10.0 - n_sats = 2_000 - ZZ = np.zeros(n_sats) - satpop = mcdpk.mc_diffmah_params_satpop( - DEFAULT_DIFFMAHPOP_PARAMS, lgm_obs + ZZ, t_obs + ZZ, ran_key - ) - for x in satpop: - assert np.all(np.isfinite(x)) diff --git a/diffmah/diffmahpop_kernels/tests/test_mc_diffmahpop_sats.py b/diffmah/diffmahpop_kernels/tests/test_mc_diffmahpop_sats.py deleted file mode 100644 index c449fcc..0000000 --- a/diffmah/diffmahpop_kernels/tests/test_mc_diffmahpop_sats.py +++ /dev/null @@ -1,90 +0,0 @@ -""" -""" - -import numpy as np -from jax import random as jran - -from ...diffmah_kernels import MAH_PBOUNDS -from .. import mc_diffmahpop_kernels_sats as mcdpk -from ..diffmahpop_params_censat import DEFAULT_DIFFMAHPOP_PARAMS - - -def test_mc_mean_diffmah_params_are_always_in_bounds(): - t_obs = 10.0 - t_0 = 13.8 - ran_key = jran.key(0) - lgmarr = np.linspace(10, 16, 20) - for lgm_obs in lgmarr: - dmah_sats, t_peak_sats = mcdpk.mc_mean_diffmah_params( - DEFAULT_DIFFMAHPOP_PARAMS, lgm_obs, t_obs, ran_key, np.log10(t_0) - ) - assert np.all(t_peak_sats > 0) - assert np.all(t_peak_sats <= t_obs) - - for p, bound in zip(dmah_sats, MAH_PBOUNDS): - assert np.all(bound[0] < p) - assert np.all(p < bound[1]) - - -def test_mc_mean_diffmah_params(): - t_obs = 10.0 - t_0 = 13.8 - ran_key = jran.key(0) - for lgm_obs in np.linspace(10, 16, 20): - args = DEFAULT_DIFFMAHPOP_PARAMS, lgm_obs, t_obs, ran_key, np.log10(t_0) - _res = mcdpk.mc_mean_diffmah_params(*args) - dmah_sats, t_peak_sats = _res - for _x in _res: - assert np.all(np.isfinite(_x)) - - -def test_mc_diffmah_params_singlesat(): - ran_key = jran.key(0) - t_0 = 13.0 - lgt0 = np.log10(t_0) - t_obs = 10.0 - lgmarr = np.linspace(10, 15, 20) - for lgm_obs in lgmarr: - args = (DEFAULT_DIFFMAHPOP_PARAMS, lgm_obs, t_obs, ran_key, lgt0) - _res = mcdpk.mc_diffmah_params_singlesat(*args) - mah_params, t_peak_sats = _res - assert np.all(np.isfinite(mah_params.logtc)) - assert np.all(np.isfinite(t_peak_sats)) - - -def test_predict_mah_moments_singlebin(): - ran_key = jran.key(0) - t_0 = 13.0 - lgt0 = np.log10(t_0) - t_obs = 10.0 - tarr = np.linspace(0.1, t_obs, 100) - lgmarr = np.linspace(10, 15, 20) - for lgm_obs in lgmarr: - args = (DEFAULT_DIFFMAHPOP_PARAMS, tarr, lgm_obs, t_obs, ran_key, lgt0) - mean_log_mah, std_log_mah, f_peaked = mcdpk.predict_mah_moments_singlebin(*args) - assert np.all(np.isfinite(mean_log_mah)) - assert np.all(np.isfinite(std_log_mah)) - assert np.all(np.isfinite(f_peaked)) - assert np.all(f_peaked >= 0) - assert np.all(f_peaked <= 1) - - -def test_mc_diffmah_halo_sample(): - ran_key = jran.key(0) - t_0 = 13.0 - lgt0 = np.log10(t_0) - t_obs = 10.0 - tarr = np.linspace(0.1, t_obs, 100) - lgmarr = np.linspace(10, 16, 20) - for lgm_obs in lgmarr: - args = (DEFAULT_DIFFMAHPOP_PARAMS, tarr, lgm_obs, t_obs, ran_key, lgt0) - _res = mcdpk._mc_diffmah_halo_sample(*args) - (mah_params, t_peak_sats, dmhdt_sats, log_mah_sats) = _res - assert np.all(np.isfinite(mah_params)) - - assert np.all(np.isfinite(t_peak_sats)) - assert np.all(t_peak_sats > 0.0) - assert np.all(t_peak_sats < t_obs) - - assert np.all(np.isfinite(log_mah_sats)) - assert np.all(np.isfinite(dmhdt_sats)) diff --git a/diffmah/diffmahpop_kernels/tests/test_monocens_fithelp.py b/diffmah/diffmahpop_kernels/tests/test_monocens_fithelp.py deleted file mode 100644 index f3b6311..0000000 --- a/diffmah/diffmahpop_kernels/tests/test_monocens_fithelp.py +++ /dev/null @@ -1,40 +0,0 @@ -""" -""" - -import numpy as np -from jax import random as jran - -from ... import diffmah_kernels -from .. import diffmahpop_params_monocensat as dpp -from .. import monocens_fithelp - - -def test_loss_grads(): - ran_key = jran.key(0) - t_obs = 10.0 - t_0 = 13.0 - lgt0 = np.log10(t_0) - tarr = np.linspace(0.1, t_obs, 100) - t_peak_target = t_0 - - lgmarr = np.linspace(10, 16, 20) - for lgm_obs in lgmarr: - mah_params = diffmah_kernels.DEFAULT_MAH_PARAMS._replace(logm0=lgm_obs) - __, mean_log_mah = diffmah_kernels.mah_singlehalo( - mah_params, tarr, t_peak_target, lgt0 - ) - std_log_mah = np.zeros_like(mean_log_mah) + 0.5 - target_frac_peaked = np.zeros_like(mean_log_mah) + 0.5 - args = ( - dpp.DEFAULT_DIFFMAHPOP_PARAMS, - tarr, - lgm_obs, - t_obs, - ran_key, - lgt0, - mean_log_mah, - std_log_mah, - target_frac_peaked, - ) - loss = monocens_fithelp._loss_mah_moments_singlebin(*args) - assert np.all(np.isfinite(loss)) diff --git a/diffmah/diffmahpop_kernels/tests/test_monocens_fixed_tpeak_fithelp.py b/diffmah/diffmahpop_kernels/tests/test_monocens_fixed_tpeak_fithelp.py deleted file mode 100644 index ddec752..0000000 --- a/diffmah/diffmahpop_kernels/tests/test_monocens_fixed_tpeak_fithelp.py +++ /dev/null @@ -1,44 +0,0 @@ -""" -""" - -import numpy as np -from jax import random as jran - -from ... import diffmah_kernels -from .. import diffmahpop_params_monocensat as dpp -from .. import monocens_fixed_tpeak_fithelp - - -def test_loss_grads(): - ran_key = jran.key(0) - t_obs = 10.0 - t_0 = 13.0 - lgt0 = np.log10(t_0) - tarr = np.linspace(0.1, t_obs, 100) - - t_peak_target = t_0 - n_singlebin = 175 - t_peak_singlebin = np.linspace(3, 12, n_singlebin) - - lgmarr = np.linspace(10, 16, 20) - for lgm_obs in lgmarr: - mah_params = diffmah_kernels.DEFAULT_MAH_PARAMS._replace(logm0=lgm_obs) - __, mean_log_mah = diffmah_kernels.mah_singlehalo( - mah_params, tarr, t_peak_target, lgt0 - ) - std_log_mah = np.zeros_like(mean_log_mah) + 0.5 - target_frac_peaked = np.zeros_like(mean_log_mah) + 0.5 - args = ( - dpp.DEFAULT_DIFFMAHPOP_U_PARAMS, - tarr, - lgm_obs, - t_obs, - t_peak_singlebin, - ran_key, - lgt0, - mean_log_mah, - std_log_mah, - target_frac_peaked, - ) - loss = monocens_fixed_tpeak_fithelp._loss_mah_moments_singlebin_u_params(*args) - assert np.all(np.isfinite(loss)) diff --git a/diffmah/diffmahpop_kernels/tests/test_monocensat_fithelp.py b/diffmah/diffmahpop_kernels/tests/test_monocensat_fithelp.py deleted file mode 100644 index 1d07321..0000000 --- a/diffmah/diffmahpop_kernels/tests/test_monocensat_fithelp.py +++ /dev/null @@ -1,49 +0,0 @@ -""" -""" - -import numpy as np -from jax import random as jran - -from ... import diffmah_kernels -from .. import diffmahpop_params_monocensat as dpp -from .. import monocensat_fithelp - - -def test_loss_grads(): - ran_key = jran.key(0) - t_obs = 10.0 - t_0 = 13.0 - lgt0 = np.log10(t_0) - tarr = np.linspace(0.1, t_obs, 100) - t_peak = t_0 - - lgmarr = np.linspace(10, 16, 20) - for lgm_obs in lgmarr: - mah_params = diffmah_kernels.DEFAULT_MAH_PARAMS._replace(logm0=lgm_obs) - __, mean_log_mah = diffmah_kernels.mah_singlehalo( - mah_params, tarr, t_peak, lgt0 - ) - std_log_mah = np.zeros_like(mean_log_mah) + 0.5 - target_frac_peaked = np.zeros_like(mean_log_mah) + 0.5 - args = ( - dpp.DEFAULT_DIFFMAHPOP_PARAMS, - tarr.reshape((1, -1)), - lgm_obs + np.zeros(1), - t_obs + np.zeros(1), - tarr.reshape((1, -1)), - lgm_obs + np.zeros(1), - t_obs + np.zeros(1), - ran_key, - lgt0, - mean_log_mah.reshape((1, -1)), - std_log_mah.reshape((1, -1)), - target_frac_peaked.reshape((1, -1)), - mean_log_mah.reshape((1, -1)), - std_log_mah.reshape((1, -1)), - target_frac_peaked.reshape((1, -1)), - ) - loss, grads = monocensat_fithelp.loss_and_grads_mah_moments_multibin_censat( - *args - ) - assert np.all(np.isfinite(loss)) - assert np.all(np.isfinite(grads)) diff --git a/diffmah/diffmahpop_kernels/tests/test_monosats_fithelp.py b/diffmah/diffmahpop_kernels/tests/test_monosats_fithelp.py deleted file mode 100644 index 0721dc9..0000000 --- a/diffmah/diffmahpop_kernels/tests/test_monosats_fithelp.py +++ /dev/null @@ -1,40 +0,0 @@ -""" -""" - -import numpy as np -from jax import random as jran - -from ... import diffmah_kernels -from .. import diffmahpop_params_monocensat as dpp -from .. import monosats_fithelp - - -def test_loss_grads(): - ran_key = jran.key(0) - t_obs = 10.0 - t_0 = 13.0 - lgt0 = np.log10(t_0) - tarr = np.linspace(0.1, t_obs, 100) - t_peak = t_0 - - lgmarr = np.linspace(10, 16, 20) - for lgm_obs in lgmarr: - mah_params = diffmah_kernels.DEFAULT_MAH_PARAMS._replace(logm0=lgm_obs) - __, mean_log_mah = diffmah_kernels.mah_singlehalo( - mah_params, tarr, t_peak, lgt0 - ) - std_log_mah = np.zeros_like(mean_log_mah) + 0.5 - target_frac_peaked = np.zeros_like(mean_log_mah) + 0.5 - args = ( - dpp.DEFAULT_DIFFMAHPOP_PARAMS, - tarr, - lgm_obs, - t_obs, - ran_key, - lgt0, - mean_log_mah, - std_log_mah, - target_frac_peaked, - ) - loss = monosats_fithelp._loss_mah_moments_singlebin(*args) - assert np.all(np.isfinite(loss)) diff --git a/diffmah/diffmahpop_kernels/tests/test_variance_fithelp.py b/diffmah/diffmahpop_kernels/tests/test_variance_fithelp.py deleted file mode 100644 index 7435a0d..0000000 --- a/diffmah/diffmahpop_kernels/tests/test_variance_fithelp.py +++ /dev/null @@ -1,38 +0,0 @@ -""" -""" - -import numpy as np -from jax import random as jran - -from ... import diffmah_kernels -from .. import diffmahpop_params as dpp -from .. import variance_fithelp - - -def test_loss_grads(): - ran_key = jran.key(0) - t_obs = 10.0 - t_0 = 13.0 - lgt0 = np.log10(t_0) - tarr = np.linspace(0.1, t_obs, 100) - t_peak = t_0 - - lgmarr = np.linspace(10, 16, 20) - for lgm_obs in lgmarr: - mah_params = diffmah_kernels.DEFAULT_MAH_PARAMS._replace(logm0=lgm_obs) - __, mean_log_mah = diffmah_kernels.mah_singlehalo( - mah_params, tarr, t_peak, lgt0 - ) - std_log_mah = np.zeros_like(mean_log_mah) + 0.5 - args = ( - dpp.DEFAULT_DIFFMAHPOP_PARAMS, - tarr, - lgm_obs, - t_obs, - ran_key, - lgt0, - mean_log_mah, - std_log_mah, - ) - loss = variance_fithelp._loss_mah_moments_singlebin(*args) - assert np.all(np.isfinite(loss)) diff --git a/diffmah/diffmahpop_kernels/variance_fithelp.py b/diffmah/diffmahpop_kernels/variance_fithelp.py deleted file mode 100644 index 5e2312b..0000000 --- a/diffmah/diffmahpop_kernels/variance_fithelp.py +++ /dev/null @@ -1,92 +0,0 @@ -""" -""" - -from jax import jit as jjit -from jax import numpy as jnp -from jax import random as jran -from jax import value_and_grad, vmap - -from . import mc_diffmahpop_kernels as mcdk - -T_OBS_FIT_MIN = 0.5 - - -@jjit -def _mse(x, y): - d = y - x - return jnp.mean(d * d) - - -@jjit -def _loss_mah_moments_singlebin( - diffmahpop_params, - tarr, - lgm_obs, - t_obs, - ran_key, - lgt0, - target_mean_log_mah, - target_std_log_mah, -): - _preds = mcdk.predict_mah_moments_singlebin( - diffmahpop_params, tarr, lgm_obs, t_obs, ran_key, lgt0 - ) - mean_log_mah, std_log_mah = _preds - loss = _mse(mean_log_mah, target_mean_log_mah) - loss = loss + _mse(std_log_mah, target_std_log_mah) - return loss - - -_U = (None, 0, 0, 0, 0, None, 0, 0) -_loss_mah_moments_multibin_vmap = jjit(vmap(_loss_mah_moments_singlebin, in_axes=_U)) - - -@jjit -def _loss_mah_moments_multibin_kern( - diffmahpop_params, - tarr_matrix, - lgm_obs_arr, - t_obs_arr, - ran_key, - lgt0, - target_mean_log_mahs, - target_std_log_mahs, -): - ran_keys = jran.split(ran_key, tarr_matrix.shape[0]) - return _loss_mah_moments_multibin_vmap( - diffmahpop_params, - tarr_matrix, - lgm_obs_arr, - t_obs_arr, - ran_keys, - lgt0, - target_mean_log_mahs, - target_std_log_mahs, - ) - - -@jjit -def loss_mah_moments_multibin( - diffmahpop_params, - tarr_matrix, - lgm_obs_arr, - t_obs_arr, - ran_key, - lgt0, - target_mean_log_mahs, - target_std_log_mahs, -): - losses = _loss_mah_moments_multibin_kern( - diffmahpop_params, - tarr_matrix, - lgm_obs_arr, - t_obs_arr, - ran_key, - lgt0, - target_mean_log_mahs, - target_std_log_mahs, - ) - return jnp.mean(losses) - - -loss_and_grads_mah_moments_multibin = value_and_grad(loss_mah_moments_multibin)