Skip to content

Commit

Permalink
Merge pull request #82 from ArgonneCPAC/diffburst_main_tasso
Browse files Browse the repository at this point in the history
Update diffburst.py with API-breaking changes to function and parameter names
  • Loading branch information
aphearin authored Jan 9, 2024
2 parents 4702648 + 19764cd commit 2a52de8
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 74 deletions.
87 changes: 54 additions & 33 deletions dsps/sfh/diffburst.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@
LGYR_PEAK_MIN = 5.0
LGAGE_MAX = 9.0
DLGAGE_MIN = 1.0
LGAGE_K = 0.1
LGYR_PEAK_MAX = LGAGE_MAX - DLGAGE_MIN

LGFBURST_MIN = -8.0
LGFBURST_MAX = -1.0

LGAGE_K = 0.1
LGFB_X0, LGFB_K = -2, 0.1
DEFAULT_LGFBURST = -4.0


class BurstParams(typing.NamedTuple):
Expand All @@ -29,7 +31,8 @@ class BurstUParams(typing.NamedTuple):
u_lgyr_max: jnp.float32


DEFAULT_PARAMS = BurstParams(-3.0, 5.5, 7.0)
DEFAULT_BURST_PARAMS = BurstParams(-3.0, 5.5, 7.0)
DEFAULT_LGFBURST = DEFAULT_BURST_PARAMS.lgfburst


@jjit
Expand Down Expand Up @@ -57,8 +60,10 @@ def calc_bursty_age_weights(burst_params, smooth_age_weights, ssp_lg_age_gyr):
"""
burst_params = BurstParams(*burst_params)

ssp_lg_age_yr = ssp_lg_age_gyr + 9
burst_weights = _age_weights_from_params(ssp_lg_age_yr, burst_params)
ssp_lg_age_yr = ssp_lg_age_gyr + 9.0
burst_weights = _pureburst_age_weights_from_params(
ssp_lg_age_yr, burst_params.lgyr_peak, burst_params.lgyr_max
)

fb = 10**burst_params.lgfburst
age_weights = fb * burst_weights + (1 - fb) * smooth_age_weights
Expand Down Expand Up @@ -104,13 +109,12 @@ def _zero_safe_normalize(x):


@jjit
def _age_weights_from_params(lgyr, burst_params):
burst_params = BurstParams(*burst_params)
dlgyr_support = burst_params.lgyr_max - burst_params.lgyr_peak
lgyr_min = burst_params.lgyr_peak - dlgyr_support
twx0 = 0.5 * (lgyr_min + burst_params.lgyr_max)
def _pureburst_age_weights_from_params(lgyr, lgyr_peak, lgyr_max):
dlgyr_support = lgyr_max - lgyr_peak
lgyr_min = lgyr_peak - dlgyr_support
twx0 = 0.5 * (lgyr_min + lgyr_max)

dlgyr = burst_params.lgyr_max - lgyr_min
dlgyr = lgyr_max - lgyr_min
twh = dlgyr / 6

tw_gauss = triweight_gaussian(lgyr, twx0, twh)
Expand All @@ -120,29 +124,45 @@ def _age_weights_from_params(lgyr, burst_params):


@jjit
def _age_weights_from_u_params(lgyr, u_burst_params):
u_burst_params = BurstUParams(*u_burst_params)
params = _get_params_from_u_params(u_burst_params)
return _age_weights_from_params(lgyr, params)
def _pureburst_age_weights_from_u_params(lgyr, u_lgyr_peak, u_lgyr_max):
lgyr_peak = _get_lgyr_peak_from_u_lgyr_peak(u_lgyr_peak)
lgyr_max = _get_lgyr_max_from_lgyr_peak(lgyr_peak, u_lgyr_max)
return _pureburst_age_weights_from_params(lgyr, lgyr_peak, lgyr_max)


@jjit
def _get_params_from_u_params(u_params):
u_lgfburst, u_lgyr_peak, u_lgyr_max = u_params
lgfburst = _get_lgfburst_from_u_lgfburst(u_lgfburst)
lgyr_peak, lgyr_max = _get_tburst_params_from_tburst_u_params(
u_lgyr_peak, u_lgyr_max
)
params = BurstParams(lgfburst, lgyr_peak, lgyr_max)
return params


@jjit
def _get_tburst_params_from_tburst_u_params(u_lgyr_peak, u_lgyr_max):
lgyr_peak = _get_lgyr_peak_from_u_lgyr_peak(u_lgyr_peak)
lgyr_max = _get_lgyr_max_from_lgyr_peak(lgyr_peak, u_lgyr_max)
params = lgfburst, lgyr_peak, lgyr_max
return params
return lgyr_peak, lgyr_max


@jjit
def _get_tburst_u_params_from_tburst_params(lgyr_peak, lgyr_max):
u_lgyr_peak = _get_u_lgyr_peak_from_lgyr_peak(lgyr_peak)
u_lgyr_max = _get_u_lgyr_max_from_lgyr_peak(lgyr_peak, lgyr_max)
return u_lgyr_peak, u_lgyr_max


@jjit
def _get_u_params_from_params(params):
lgfburst, lgyr_peak, lgyr_max = params
u_lgfburst = _get_u_lgfburst_from_lgfburst(lgfburst)
u_lgyr_peak = _get_u_lgyr_peak_from_lgyr_peak(lgyr_peak)
u_lgyr_max = _get_u_lgyr_max_from_lgyr_peak(lgyr_peak, lgyr_max)
u_params = u_lgfburst, u_lgyr_peak, u_lgyr_max
u_lgyr_peak, u_lgyr_max = _get_tburst_u_params_from_tburst_params(
lgyr_peak, lgyr_max
)
u_params = BurstUParams(u_lgfburst, u_lgyr_peak, u_lgyr_max)
return u_params


Expand Down Expand Up @@ -175,8 +195,7 @@ def _get_lgyr_max_from_lgyr_peak(lgyr_peak, u_lgyr_max):

@jjit
def _get_u_lgyr_peak_from_lgyr_peak(lgyr_peak):
lo = LGYR_PEAK_MIN
hi = LGAGE_MAX - DLGAGE_MIN
lo, hi = LGYR_PEAK_MIN, LGYR_PEAK_MAX
x0 = 0.5 * (hi + lo)
u_lgyr_peak = _inverse_sigmoid(lgyr_peak, x0, LGAGE_K, lo, hi)
return u_lgyr_peak
Expand All @@ -190,24 +209,26 @@ def _get_u_lgyr_max_from_lgyr_peak(lgyr_peak, lgyr_max):
return u_lgyr_max


DEFAULT_U_PARAMS = BurstUParams(
*[float(u_p) for u_p in _get_u_params_from_params(DEFAULT_PARAMS)]
DEFAULT_BURST_U_PARAMS = BurstUParams(
*[float(u_p) for u_p in _get_u_params_from_params(DEFAULT_BURST_PARAMS)]
)


@jjit
def _compute_bursty_age_weights_from_params(
lgyr_since_burst, age_weights, fburst, params
):
burst_weights = _age_weights_from_params(lgyr_since_burst, params)
def _compute_bursty_age_weights_from_params(lgyr_since_burst, age_weights, params):
lgfburst, lgyr_peak, lgyr_max = params
fburst = 10**lgfburst
burst_weights = _pureburst_age_weights_from_params(
lgyr_since_burst, lgyr_peak, lgyr_max
)
age_weights = fburst * burst_weights + (1 - fburst) * age_weights
return age_weights


@jjit
def _compute_bursty_age_weights_from_u_params(
lgyr_since_burst, age_weights, fburst, u_params
):
burst_weights = _age_weights_from_u_params(lgyr_since_burst, u_params)
age_weights = fburst * burst_weights + (1 - fburst) * age_weights
def _compute_bursty_age_weights_from_u_params(lgyr_since_burst, age_weights, u_params):
params = _get_params_from_u_params(u_params)
age_weights = _compute_bursty_age_weights_from_params(
lgyr_since_burst, age_weights, params
)
return age_weights
Loading

0 comments on commit 2a52de8

Please sign in to comment.