Skip to content

Commit

Permalink
Merge pull request #145 from ArgonneCPAC/satcal
Browse files Browse the repository at this point in the history
Recalibrate model with satellite-specific freedom for logm0
  • Loading branch information
aphearin authored Oct 16, 2024
2 parents 02d1a15 + 13d16bb commit b0fe84f
Show file tree
Hide file tree
Showing 44 changed files with 2,039 additions and 103 deletions.
2 changes: 1 addition & 1 deletion diffmah/diffmahpop_kernels/bimod_cens_fithelp.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from jax import value_and_grad, vmap

from . import bimod_censat_params
from . import mc_bimod_censat as mcdk
from . import mc_bimod_cens as mcdk

T_OBS_FIT_MIN = 0.5

Expand Down
78 changes: 78 additions & 0 deletions diffmah/diffmahpop_kernels/bimod_censat_fithelp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
"""
"""

from jax import jit as jjit
from jax import random as jran
from jax import value_and_grad

from . import bimod_cens_fithelp, bimod_sats_fithelp
from .bimod_censat_params import (
DEFAULT_DIFFMAHPOP_PARAMS,
DEFAULT_DIFFMAHPOP_U_PARAMS,
get_diffmahpop_params_from_u_params,
)


@jjit
def loss_mah_moments_multibin_censat(
varied_diffmahpop_params,
tarr_matrix_cens,
lgm_obs_arr_cens,
t_obs_arr_cens,
tarr_matrix_sats,
lgm_obs_arr_sats,
t_obs_arr_sats,
ran_key,
lgt0,
target_mean_log_mahs_cens,
target_std_log_mahs_cens,
target_frac_peaked_cens,
target_mean_log_mahs_sats,
target_std_log_mahs_sats,
target_frac_peaked_sats,
):
diffmahpop_params = DEFAULT_DIFFMAHPOP_PARAMS._replace(
**varied_diffmahpop_params._asdict()
)
ran_key_cens, ran_key_sats = jran.split(ran_key, 2)
loss_cens = bimod_cens_fithelp.loss_mah_moments_multibin(
diffmahpop_params,
tarr_matrix_cens,
lgm_obs_arr_cens,
t_obs_arr_cens,
ran_key_cens,
lgt0,
target_mean_log_mahs_cens,
target_std_log_mahs_cens,
target_frac_peaked_cens,
)

loss_sats = bimod_sats_fithelp.loss_mah_moments_multibin(
diffmahpop_params,
tarr_matrix_sats,
lgm_obs_arr_sats,
t_obs_arr_sats,
ran_key_sats,
lgt0,
target_mean_log_mahs_sats,
target_std_log_mahs_sats,
target_frac_peaked_sats,
)
return loss_cens + loss_sats


loss_and_grads_mah_moments_multibin_censat = jjit(
value_and_grad(loss_mah_moments_multibin_censat)
)


@jjit
def loss_mah_moments_multibin_censat_u_params(u_params, loss_data):
u_params = DEFAULT_DIFFMAHPOP_U_PARAMS._replace(**u_params._asdict())
params = get_diffmahpop_params_from_u_params(u_params)
return loss_mah_moments_multibin_censat(params, *loss_data)


loss_and_grads_mah_moments_multibin_censat_u_params = jjit(
value_and_grad(loss_mah_moments_multibin_censat_u_params)
)
47 changes: 39 additions & 8 deletions diffmah/diffmahpop_kernels/bimod_censat_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@
logtc_bimod,
)
from .bimod_logm0_kernels import logm0_pop_bimod
from .bimod_logm0_sats import logm0_pop_bimod_sats
from .t_peak_kernels import tp_pdf_cens_flex, tp_pdf_sats

DEFAULT_DIFFMAHPOP_PDICT = OrderedDict()
COMPONENT_PDICTS = (
tp_pdf_cens_flex.DEFAULT_TPCENS_PDICT,
tp_pdf_sats.DEFAULT_TP_SATS_PDICT,
logm0_pop_bimod.DEFAULT_LOGM0_PDICT,
logm0_pop_bimod_sats.DEFAULT_LOGM0_PDICT,
logtc_bimod.LOGTC_PDICT,
early_index_bimod.EARLY_INDEX_PDICT,
late_index_bimod.LATE_INDEX_PDICT,
Expand All @@ -36,6 +38,7 @@
tp_pdf_cens_flex.DEFAULT_TPCENS_U_PARAMS._asdict(),
tp_pdf_sats.DEFAULT_TP_SATS_U_PARAMS._asdict(),
logm0_pop_bimod.DEFAULT_LOGM0POP_U_PARAMS._asdict(),
logm0_pop_bimod_sats.DEFAULT_LOGM0POP_U_PARAMS._asdict(),
logtc_bimod.DEFAULT_LOGTC_U_PARAMS._asdict(),
early_index_bimod.DEFAULT_EARLY_INDEX_U_PARAMS._asdict(),
late_index_bimod.DEFAULT_LATE_INDEX_U_PARAMS._asdict(),
Expand Down Expand Up @@ -66,6 +69,14 @@ def get_component_model_params(diffmahpop_params):
for key in logm0_pop_bimod.LGM0Pop_Params._fields
]
)

logm0_params_sats = logm0_pop_bimod_sats.LGM0Pop_Params(
*[
getattr(diffmahpop_params, key)
for key in logm0_pop_bimod_sats.LGM0Pop_Params._fields
]
)

logtc_params = logtc_bimod.Logtc_Params(
*[getattr(diffmahpop_params, key) for key in logtc_bimod.Logtc_Params._fields]
)
Expand Down Expand Up @@ -96,6 +107,7 @@ def get_component_model_params(diffmahpop_params):
tp_pdf_cens_flex_params,
tp_pdf_sats_params,
logm0_params,
logm0_params_sats,
logtc_params,
early_index_params,
late_index_params,
Expand Down Expand Up @@ -124,6 +136,14 @@ def get_component_model_u_params(diffmahpop_u_params):
for key in logm0_pop_bimod.LGM0Pop_UParams._fields
]
)

logm0_sats_u_params = logm0_pop_bimod_sats.LGM0Pop_UParams(
*[
getattr(diffmahpop_u_params, key)
for key in logm0_pop_bimod_sats.LGM0Pop_UParams._fields
]
)

logtc_u_params = logtc_bimod.Logtc_UParams(
*[
getattr(diffmahpop_u_params, key)
Expand Down Expand Up @@ -161,6 +181,7 @@ def get_component_model_u_params(diffmahpop_u_params):
tp_pdf_cens_flex_u_params,
tp_pdf_sats_u_params,
logm0_u_params,
logm0_sats_u_params,
logtc_u_params,
early_index_u_params,
late_index_u_params,
Expand All @@ -172,14 +193,19 @@ def get_component_model_u_params(diffmahpop_u_params):
@jjit
def get_diffmahpop_params_from_u_params(diffmahpop_u_params):
component_model_u_params = get_component_model_u_params(diffmahpop_u_params)
tpc_u_params, tps_u_params, logm0_u_params = component_model_u_params[:3]
logtc_u_params = component_model_u_params[3]
early_index_u_params, late_index_u_params = component_model_u_params[4:6]
fec_u_params, cov_u_params = component_model_u_params[6:]
tpc_u_params, tps_u_params, logm0_u_params, logm0_sats_u_params = (
component_model_u_params[:4]
)
logtc_u_params = component_model_u_params[4]
early_index_u_params, late_index_u_params = component_model_u_params[5:7]
fec_u_params, cov_u_params = component_model_u_params[7:]

tpc_params = tp_pdf_cens_flex.get_bounded_tp_cens_params(tpc_u_params)
tps_params = tp_pdf_sats.get_bounded_tp_sat_params(tps_u_params)
logm0_params = logm0_pop_bimod.get_bounded_m0pop_params(logm0_u_params)
logm0_sats_params = logm0_pop_bimod_sats.get_bounded_m0pop_params(
logm0_sats_u_params
)
logtc_params = logtc_bimod.get_bounded_logtc_params(logtc_u_params)
early_index_params = early_index_bimod.get_bounded_early_index_params(
early_index_u_params
Expand All @@ -195,6 +221,7 @@ def get_diffmahpop_params_from_u_params(diffmahpop_u_params):
tpc_params,
tps_params,
logm0_params,
logm0_sats_params,
logtc_params,
early_index_params,
late_index_params,
Expand All @@ -211,14 +238,17 @@ def get_diffmahpop_params_from_u_params(diffmahpop_u_params):
@jjit
def get_diffmahpop_u_params_from_params(diffmahpop_params):
component_model_params = get_component_model_params(diffmahpop_params)
tpc_params, tps_params, logm0_params = component_model_params[:3]
logtc_params = component_model_params[3]
early_index_params, late_index_params = component_model_params[4:6]
fec_params, cov_params = component_model_params[6:]
tpc_params, tps_params, logm0_params, logm0_sats_params = component_model_params[:4]
logtc_params = component_model_params[4]
early_index_params, late_index_params = component_model_params[5:7]
fec_params, cov_params = component_model_params[7:]

tpc_u_params = tp_pdf_cens_flex.get_unbounded_tp_cens_params(tpc_params)
tps_u_params = tp_pdf_sats.get_unbounded_tp_sat_params(tps_params)
logm0_u_params = logm0_pop_bimod.get_unbounded_m0pop_params(logm0_params)
logm0_sats_u_params = logm0_pop_bimod_sats.get_unbounded_m0pop_params(
logm0_sats_params
)
logtc_u_params = logtc_bimod.get_unbounded_logtc_params(logtc_params)
early_index_u_params = early_index_bimod.get_unbounded_early_index_params(
early_index_params
Expand All @@ -233,6 +263,7 @@ def get_diffmahpop_u_params_from_params(diffmahpop_params):
tpc_u_params,
tps_u_params,
logm0_u_params,
logm0_sats_u_params,
logtc_u_params,
early_index_u_params,
late_index_u_params,
Expand Down
10 changes: 5 additions & 5 deletions diffmah/diffmahpop_kernels/bimod_logm0_kernels/logm0_c0_early.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
from ...utils import _inverse_sigmoid, _sig_slope, _sigmoid

DEFAULT_LGM0POP_C0_PDICT = OrderedDict(
lgm0pop_c0_ytp_early=0.011,
lgm0pop_c0_ylo_early=-0.086,
lgm0pop_c0_clip_c0_early=0.516,
lgm0pop_c0_clip_c1_early=-0.056,
lgm0pop_c0_t_obs_x0_early=1.504,
lgm0pop_c0_ytp_early=0.020,
lgm0pop_c0_ylo_early=-0.130,
lgm0pop_c0_clip_c0_early=0.897,
lgm0pop_c0_clip_c1_early=-0.090,
lgm0pop_c0_t_obs_x0_early=1.51,
)
LGM0Pop_C0_Params = namedtuple("LGM0Pop_C0_Params", DEFAULT_LGM0POP_C0_PDICT.keys())
DEFAULT_LGM0POP_C0_PARAMS = LGM0Pop_C0_Params(**DEFAULT_LGM0POP_C0_PDICT)
Expand Down
10 changes: 5 additions & 5 deletions diffmah/diffmahpop_kernels/bimod_logm0_kernels/logm0_c0_late.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
from ...utils import _inverse_sigmoid, _sig_slope, _sigmoid

DEFAULT_LGM0POP_C0_PDICT = OrderedDict(
lgm0pop_c0_ytp_late=0.012,
lgm0pop_c0_ylo_late=-0.148,
lgm0pop_c0_clip_c0_late=0.876,
lgm0pop_c0_clip_c1_late=-0.077,
lgm0pop_c0_t_obs_x0_late=2.169,
lgm0pop_c0_ytp_late=0.020,
lgm0pop_c0_ylo_late=-0.140,
lgm0pop_c0_clip_c0_late=0.89,
lgm0pop_c0_clip_c1_late=-0.090,
lgm0pop_c0_t_obs_x0_late=2.533,
)
LGM0Pop_C0_Params = namedtuple("LGM0Pop_C0_Params", DEFAULT_LGM0POP_C0_PDICT.keys())
DEFAULT_LGM0POP_C0_PARAMS = LGM0Pop_C0_Params(**DEFAULT_LGM0POP_C0_PDICT)
Expand Down
12 changes: 6 additions & 6 deletions diffmah/diffmahpop_kernels/bimod_logm0_kernels/logm0_c1_early.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@
from ...utils import _inverse_sigmoid, _sig_slope, _sigmoid

DEFAULT_LGM0POP_C1_PDICT = OrderedDict(
lgm0pop_c1_ytp_early=0.0011,
lgm0pop_c1_ylo_early=-0.035,
lgm0pop_c1_clip_x0_early=6.130,
lgm0pop_c1_clip_ylo_early=0.147,
lgm0pop_c1_clip_yhi_early=0.0011,
lgm0pop_c1_t_obs_x0_early=3.033,
lgm0pop_c1_ytp_early=0.002,
lgm0pop_c1_ylo_early=-0.043,
lgm0pop_c1_clip_x0_early=7.185,
lgm0pop_c1_clip_ylo_early=0.140,
lgm0pop_c1_clip_yhi_early=0.002,
lgm0pop_c1_t_obs_x0_early=3.01,
)
LGM0Pop_C1_Params = namedtuple("LGM0Pop_C1_Params", DEFAULT_LGM0POP_C1_PDICT.keys())
DEFAULT_LGM0POP_C1_PARAMS = LGM0Pop_C1_Params(**DEFAULT_LGM0POP_C1_PDICT)
Expand Down
12 changes: 6 additions & 6 deletions diffmah/diffmahpop_kernels/bimod_logm0_kernels/logm0_c1_late.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@
from ...utils import _inverse_sigmoid, _sig_slope, _sigmoid

DEFAULT_LGM0POP_C1_PDICT = OrderedDict(
lgm0pop_c1_ytp_late=0.020,
lgm0pop_c1_ylo_late=-0.042,
lgm0pop_c1_clip_x0_late=7.855,
lgm0pop_c1_clip_ylo_late=0.149,
lgm0pop_c1_clip_yhi_late=0.005,
lgm0pop_c1_t_obs_x0_late=5.839,
lgm0pop_c1_ytp_late=0.027,
lgm0pop_c1_ylo_late=-0.048,
lgm0pop_c1_clip_x0_late=8.443,
lgm0pop_c1_clip_ylo_late=0.145,
lgm0pop_c1_clip_yhi_late=0.002,
lgm0pop_c1_t_obs_x0_late=6.377,
)
LGM0Pop_C1_Params = namedtuple("LGM0Pop_C1_Params", DEFAULT_LGM0POP_C1_PDICT.keys())
DEFAULT_LGM0POP_C1_PARAMS = LGM0Pop_C1_Params(**DEFAULT_LGM0POP_C1_PDICT)
Expand Down
Empty file.
Loading

0 comments on commit b0fe84f

Please sign in to comment.