diff --git a/dsps/sfh/diffburst.py b/dsps/sfh/diffburst.py index 4898bcb..500ffbb 100644 --- a/dsps/sfh/diffburst.py +++ b/dsps/sfh/diffburst.py @@ -10,11 +10,13 @@ LGYR_PEAK_MIN = 5.0 LGAGE_MAX = 9.0 DLGAGE_MIN = 1.0 -LGAGE_K = 0.1 +LGYR_PEAK_MAX = LGAGE_MAX - DLGAGE_MIN + LGFBURST_MIN = -8.0 LGFBURST_MAX = -1.0 + +LGAGE_K = 0.1 LGFB_X0, LGFB_K = -2, 0.1 -DEFAULT_LGFBURST = -4.0 class BurstParams(typing.NamedTuple): @@ -29,7 +31,8 @@ class BurstUParams(typing.NamedTuple): u_lgyr_max: jnp.float32 -DEFAULT_PARAMS = BurstParams(-3.0, 5.5, 7.0) +DEFAULT_BURST_PARAMS = BurstParams(-3.0, 5.5, 7.0) +DEFAULT_LGFBURST = DEFAULT_BURST_PARAMS.lgfburst @jjit @@ -57,8 +60,10 @@ def calc_bursty_age_weights(burst_params, smooth_age_weights, ssp_lg_age_gyr): """ burst_params = BurstParams(*burst_params) - ssp_lg_age_yr = ssp_lg_age_gyr + 9 - burst_weights = _age_weights_from_params(ssp_lg_age_yr, burst_params) + ssp_lg_age_yr = ssp_lg_age_gyr + 9.0 + burst_weights = _pureburst_age_weights_from_params( + ssp_lg_age_yr, burst_params.lgyr_peak, burst_params.lgyr_max + ) fb = 10**burst_params.lgfburst age_weights = fb * burst_weights + (1 - fb) * smooth_age_weights @@ -104,13 +109,12 @@ def _zero_safe_normalize(x): @jjit -def _age_weights_from_params(lgyr, burst_params): - burst_params = BurstParams(*burst_params) - dlgyr_support = burst_params.lgyr_max - burst_params.lgyr_peak - lgyr_min = burst_params.lgyr_peak - dlgyr_support - twx0 = 0.5 * (lgyr_min + burst_params.lgyr_max) +def _pureburst_age_weights_from_params(lgyr, lgyr_peak, lgyr_max): + dlgyr_support = lgyr_max - lgyr_peak + lgyr_min = lgyr_peak - dlgyr_support + twx0 = 0.5 * (lgyr_min + lgyr_max) - dlgyr = burst_params.lgyr_max - lgyr_min + dlgyr = lgyr_max - lgyr_min twh = dlgyr / 6 tw_gauss = triweight_gaussian(lgyr, twx0, twh) @@ -120,29 +124,45 @@ def _age_weights_from_params(lgyr, burst_params): @jjit -def _age_weights_from_u_params(lgyr, u_burst_params): - u_burst_params = BurstUParams(*u_burst_params) - params = _get_params_from_u_params(u_burst_params) - return _age_weights_from_params(lgyr, params) +def _pureburst_age_weights_from_u_params(lgyr, u_lgyr_peak, u_lgyr_max): + lgyr_peak = _get_lgyr_peak_from_u_lgyr_peak(u_lgyr_peak) + lgyr_max = _get_lgyr_max_from_lgyr_peak(lgyr_peak, u_lgyr_max) + return _pureburst_age_weights_from_params(lgyr, lgyr_peak, lgyr_max) @jjit def _get_params_from_u_params(u_params): u_lgfburst, u_lgyr_peak, u_lgyr_max = u_params lgfburst = _get_lgfburst_from_u_lgfburst(u_lgfburst) + lgyr_peak, lgyr_max = _get_tburst_params_from_tburst_u_params( + u_lgyr_peak, u_lgyr_max + ) + params = BurstParams(lgfburst, lgyr_peak, lgyr_max) + return params + + +@jjit +def _get_tburst_params_from_tburst_u_params(u_lgyr_peak, u_lgyr_max): lgyr_peak = _get_lgyr_peak_from_u_lgyr_peak(u_lgyr_peak) lgyr_max = _get_lgyr_max_from_lgyr_peak(lgyr_peak, u_lgyr_max) - params = lgfburst, lgyr_peak, lgyr_max - return params + return lgyr_peak, lgyr_max + + +@jjit +def _get_tburst_u_params_from_tburst_params(lgyr_peak, lgyr_max): + u_lgyr_peak = _get_u_lgyr_peak_from_lgyr_peak(lgyr_peak) + u_lgyr_max = _get_u_lgyr_max_from_lgyr_peak(lgyr_peak, lgyr_max) + return u_lgyr_peak, u_lgyr_max @jjit def _get_u_params_from_params(params): lgfburst, lgyr_peak, lgyr_max = params u_lgfburst = _get_u_lgfburst_from_lgfburst(lgfburst) - u_lgyr_peak = _get_u_lgyr_peak_from_lgyr_peak(lgyr_peak) - u_lgyr_max = _get_u_lgyr_max_from_lgyr_peak(lgyr_peak, lgyr_max) - u_params = u_lgfburst, u_lgyr_peak, u_lgyr_max + u_lgyr_peak, u_lgyr_max = _get_tburst_u_params_from_tburst_params( + lgyr_peak, lgyr_max + ) + u_params = BurstUParams(u_lgfburst, u_lgyr_peak, u_lgyr_max) return u_params @@ -175,8 +195,7 @@ def _get_lgyr_max_from_lgyr_peak(lgyr_peak, u_lgyr_max): @jjit def _get_u_lgyr_peak_from_lgyr_peak(lgyr_peak): - lo = LGYR_PEAK_MIN - hi = LGAGE_MAX - DLGAGE_MIN + lo, hi = LGYR_PEAK_MIN, LGYR_PEAK_MAX x0 = 0.5 * (hi + lo) u_lgyr_peak = _inverse_sigmoid(lgyr_peak, x0, LGAGE_K, lo, hi) return u_lgyr_peak @@ -190,24 +209,26 @@ def _get_u_lgyr_max_from_lgyr_peak(lgyr_peak, lgyr_max): return u_lgyr_max -DEFAULT_U_PARAMS = BurstUParams( - *[float(u_p) for u_p in _get_u_params_from_params(DEFAULT_PARAMS)] +DEFAULT_BURST_U_PARAMS = BurstUParams( + *[float(u_p) for u_p in _get_u_params_from_params(DEFAULT_BURST_PARAMS)] ) @jjit -def _compute_bursty_age_weights_from_params( - lgyr_since_burst, age_weights, fburst, params -): - burst_weights = _age_weights_from_params(lgyr_since_burst, params) +def _compute_bursty_age_weights_from_params(lgyr_since_burst, age_weights, params): + lgfburst, lgyr_peak, lgyr_max = params + fburst = 10**lgfburst + burst_weights = _pureburst_age_weights_from_params( + lgyr_since_burst, lgyr_peak, lgyr_max + ) age_weights = fburst * burst_weights + (1 - fburst) * age_weights return age_weights @jjit -def _compute_bursty_age_weights_from_u_params( - lgyr_since_burst, age_weights, fburst, u_params -): - burst_weights = _age_weights_from_u_params(lgyr_since_burst, u_params) - age_weights = fburst * burst_weights + (1 - fburst) * age_weights +def _compute_bursty_age_weights_from_u_params(lgyr_since_burst, age_weights, u_params): + params = _get_params_from_u_params(u_params) + age_weights = _compute_bursty_age_weights_from_params( + lgyr_since_burst, age_weights, params + ) return age_weights diff --git a/dsps/sfh/tests/test_diffburst.py b/dsps/sfh/tests/test_diffburst.py index 6fa0e61..6562e34 100644 --- a/dsps/sfh/tests/test_diffburst.py +++ b/dsps/sfh/tests/test_diffburst.py @@ -1,26 +1,33 @@ """ """ import numpy as np +from jax import random as jran from .. import diffburst as db +TOL = 1e-2 + def test_params_invert(): + ran_key = jran.PRNGKey(0) n_tests = 10 for __ in range(n_tests): - u_p = np.random.uniform(-10, 10, len(db.DEFAULT_PARAMS)) + ran_key, u_key = jran.split(ran_key, 2) + u = jran.uniform( + u_key, minval=-10, maxval=10, shape=(len(db.DEFAULT_BURST_PARAMS),) + ) + u_p = db.BurstUParams(*u) p = db._get_params_from_u_params(u_p) u_p2 = db._get_u_params_from_params(p) - assert np.allclose(u_p, u_p2, rtol=1e-3) + assert np.allclose(u_p, u_p2, rtol=TOL) def test_age_weights_is_finite_and_zero_for_edge_cases(): lgyr_peak, lgyr_max = 6.5, 8.0 - params = db.DEFAULT_LGFBURST, lgyr_peak, lgyr_max # all times in the lgyr_table are after lgyr_max lgyr_test = np.linspace(lgyr_max, lgyr_max + 2, 100) - age_weights = db._age_weights_from_params(lgyr_test, params) + age_weights = db._pureburst_age_weights_from_params(lgyr_test, lgyr_peak, lgyr_max) assert np.all(np.isfinite(age_weights)) assert np.allclose(age_weights, 0.0) @@ -28,129 +35,145 @@ def test_age_weights_is_finite_and_zero_for_edge_cases(): dlgyr_support = lgyr_max - lgyr_peak lgyr_min = lgyr_peak - dlgyr_support lgyr_test = np.linspace(lgyr_min - 2, lgyr_min, 100) - age_weights = db._age_weights_from_params(lgyr_test, params) + age_weights = db._pureburst_age_weights_from_params(lgyr_test, lgyr_peak, lgyr_max) assert np.all(np.isfinite(age_weights)) assert np.allclose(age_weights, 0.0) -def test_age_weights_from_params_scales_with_lgyr_max_as_expected(): +def test_pureburst_age_weights_from_params_scales_with_lgyr_max_as_expected(): n_tests = 10 lgyr_peak = 5.5 + ran_key = jran.PRNGKey(0) for __ in range(n_tests): - lgyr_max = np.random.uniform(7, 8) + ran_key, lgyr_max_key = jran.split(ran_key, 2) + lgyr_max = jran.uniform(lgyr_max_key, minval=7, maxval=8, shape=()) params = db.BurstParams(db.DEFAULT_LGFBURST, lgyr_peak, lgyr_max) lgyr = np.linspace(4, 10.5, 100) - age_weights = db._age_weights_from_params(lgyr, params) + age_weights = db._pureburst_age_weights_from_params(lgyr, lgyr_peak, lgyr_max) zmsk = lgyr > params.lgyr_max assert np.all(age_weights[zmsk] == 0) assert np.any(age_weights > 0) - age_weight_at_lgyr_peak = db._age_weights_from_params(lgyr_peak, params) + age_weight_at_lgyr_peak = db._pureburst_age_weights_from_params( + lgyr_peak, lgyr_peak, lgyr_max + ) assert age_weight_at_lgyr_peak > 0 lgyr_post_peak = np.linspace(lgyr_peak, lgyr_peak + 1, 10) - age_weights_post_peak = db._age_weights_from_params(lgyr_post_peak, params) + age_weights_post_peak = db._pureburst_age_weights_from_params( + lgyr_post_peak, lgyr_peak, lgyr_max + ) assert np.all(np.diff(age_weights_post_peak) < 0) lgyr_peak = 5.5 lgyr_test = np.linspace(lgyr_peak, lgyr_peak + 1, 100) - age_weight_younger = db._age_weights_from_params( - lgyr_test, (db.DEFAULT_LGFBURST, lgyr_peak, 6.5) - ) - age_weights_older = db._age_weights_from_params( - lgyr_test, (db.DEFAULT_LGFBURST, lgyr_peak, 8.0) + age_weight_younger = db._pureburst_age_weights_from_params( + lgyr_test, lgyr_peak, 6.5 ) + age_weights_older = db._pureburst_age_weights_from_params(lgyr_test, lgyr_peak, 8.0) assert age_weight_younger[0] > age_weights_older[0] def test_age_weights_from_default_params_are_weights(): lgyr = np.arange(5.5, 10.35, 0.05) - age_weights = db._age_weights_from_params(lgyr, db.DEFAULT_PARAMS) + lgyr_peak, lgyr_max = db.DEFAULT_BURST_PARAMS[1:] + age_weights = db._pureburst_age_weights_from_params(lgyr, lgyr_peak, lgyr_max) assert np.all(np.isfinite(age_weights)) assert np.all(age_weights >= 0) assert np.any(age_weights > 0) - assert np.allclose(age_weights.sum(), 1.0, rtol=1e-3) + assert np.allclose(age_weights.sum(), 1.0, rtol=TOL) def test_age_weights_from_random_params_are_weights(): + ran_key = jran.PRNGKey(0) lgyr = np.arange(5.5, 10.35, 0.05) n_tests = 10 for __ in range(n_tests): - u_p = np.random.uniform(-10, 10, len(db.DEFAULT_PARAMS)) - age_weights = db._age_weights_from_u_params(lgyr, u_p) + ran_key, u_key = jran.split(ran_key, 2) + u_p = jran.uniform(u_key, minval=-10, maxval=10, shape=(2,)) + age_weights = db._pureburst_age_weights_from_u_params(lgyr, *u_p) assert np.all(np.isfinite(age_weights)) assert np.all(age_weights >= 0) assert np.any(age_weights > 0) - assert np.allclose(age_weights.sum(), 1.0, rtol=1e-3) + assert np.allclose(age_weights.sum(), 1.0, rtol=TOL) def test_age_weights_from_random_params_u_params_consistency(): + ran_key = jran.PRNGKey(0) lgyr = np.arange(5.5, 10.35, 0.05) n_tests = 10 for __ in range(n_tests): - u_p = np.random.uniform(-10, 10, len(db.DEFAULT_PARAMS)) - age_weights = db._age_weights_from_u_params(lgyr, u_p) + ran_key, u_key = jran.split(ran_key, 2) + u_p = jran.uniform(u_key, minval=-10, maxval=10, shape=(3,)) p = db._get_params_from_u_params(u_p) - age_weights2 = db._age_weights_from_params(lgyr, p) - assert np.allclose(age_weights, age_weights2, rtol=1e-3) + age_weights = db._pureburst_age_weights_from_u_params(lgyr, *u_p[1:]) + age_weights2 = db._pureburst_age_weights_from_params(lgyr, *p[1:]) + assert np.allclose(age_weights, age_weights2, rtol=TOL) def test_compute_bursty_age_weights_from_params(): + ran_key = jran.PRNGKey(0) lgyr_since_burst = np.arange(5.5, 10.35, 0.05) n_age = lgyr_since_burst.size - age_weights = np.random.uniform(0, 1, n_age) + age_weights = jran.uniform(ran_key, minval=0, maxval=1, shape=(n_age,)) age_weights = age_weights / age_weights.sum() - fburst = np.random.uniform(0, 1) bursty_age_weights = db._compute_bursty_age_weights_from_params( - lgyr_since_burst, age_weights, fburst, db.DEFAULT_PARAMS + lgyr_since_burst, age_weights, db.DEFAULT_BURST_PARAMS ) assert np.all(np.isfinite(bursty_age_weights)) assert np.all(bursty_age_weights >= 0) assert np.all(bursty_age_weights <= 1) - assert np.allclose(bursty_age_weights.sum(), 1.0, rtol=1e-3) + assert np.allclose(bursty_age_weights.sum(), 1.0, rtol=TOL) def test_compute_bursty_age_weights_from_u_params(): + ran_key = jran.PRNGKey(0) lgyr_since_burst = np.arange(5.5, 10.35, 0.05) n_tests = 10 for __ in range(n_tests): - u_p = np.random.uniform(-10, 10, len(db.DEFAULT_PARAMS)) + ran_key, u_key, weights_key = jran.split(ran_key, 3) + u = jran.uniform( + u_key, minval=-10, maxval=10, shape=(len(db.DEFAULT_BURST_PARAMS),) + ) + u_p = db.BurstUParams(*u) n_age = lgyr_since_burst.size - age_weights = np.random.uniform(0, 1, n_age) + age_weights = jran.uniform(weights_key, minval=0, maxval=1, shape=(n_age,)) age_weights = age_weights / age_weights.sum() - fburst = np.random.uniform(0, 1) bursty_age_weights = db._compute_bursty_age_weights_from_u_params( - lgyr_since_burst, age_weights, fburst, u_p + lgyr_since_burst, age_weights, u_p ) assert np.all(np.isfinite(bursty_age_weights)) assert np.all(bursty_age_weights >= 0) assert np.all(bursty_age_weights <= 1) - assert np.allclose(bursty_age_weights.sum(), 1.0, rtol=1e-3) + assert np.allclose(bursty_age_weights.sum(), 1.0, rtol=TOL) p = db._get_params_from_u_params(u_p) bursty_age_weights2 = db._compute_bursty_age_weights_from_params( - lgyr_since_burst, age_weights, fburst, p + lgyr_since_burst, age_weights, p ) assert np.all(np.isfinite(bursty_age_weights2)) assert np.all(bursty_age_weights2 >= 0) assert np.all(bursty_age_weights2 <= 1) - assert np.allclose(bursty_age_weights2.sum(), 1.0, rtol=1e-3) + assert np.allclose(bursty_age_weights2.sum(), 1.0, rtol=TOL) - assert np.allclose(bursty_age_weights, bursty_age_weights2, rtol=1e-3) + assert np.allclose(bursty_age_weights, bursty_age_weights2, rtol=TOL) def test_calc_bursty_age_weights(): - burst_params = db.DEFAULT_PARAMS + ran_key = jran.PRNGKey(0) + burst_params = db.DEFAULT_BURST_PARAMS n_age = 107 ssp_lg_age_gyr = np.linspace(5.0 - 9, 10.5 - 9, n_age) - smooth_age_weights = np.random.uniform(0, 1, n_age) + smooth_age_weights = jran.uniform(ran_key, minval=0, maxval=1, shape=(n_age,)) smooth_age_weights = smooth_age_weights / smooth_age_weights.sum() assert np.allclose(smooth_age_weights.sum(), 1.0, rtol=1e-4) ssp_lg_age_yr = ssp_lg_age_gyr + 9.0 - burst_weights = db._age_weights_from_params(ssp_lg_age_yr, burst_params) + burst_weights = db._pureburst_age_weights_from_params( + ssp_lg_age_yr, *burst_params[1:] + ) assert np.allclose(burst_weights.sum(), 1.0, rtol=1e-4) bursty_age_weights = db.calc_bursty_age_weights( @@ -162,13 +185,19 @@ def test_calc_bursty_age_weights(): def test_calc_bursty_age_weights_from_u_params(): + ran_key = jran.PRNGKey(0) n_age = 107 ssp_lg_age_gyr = np.linspace(5.0 - 9, 10.5 - 9, n_age) - smooth_age_weights = np.random.uniform(0, 1, n_age) + smooth_key, ran_key = jran.split(ran_key, 2) + smooth_age_weights = jran.uniform(smooth_key, minval=0, maxval=1, shape=(n_age,)) smooth_age_weights = smooth_age_weights / smooth_age_weights.sum() n_tests = 10 for i in range(n_tests): - u_burst_params = np.random.uniform(-20, 20, len(db.DEFAULT_PARAMS)) + u_key, ran_key = jran.split(ran_key, 2) + u = jran.uniform( + u_key, minval=-20, maxval=20, shape=(len(db.DEFAULT_BURST_PARAMS),) + ) + u_burst_params = db.BurstUParams(*u) bursty_age_weights = db.calc_bursty_age_weights_from_u_params( u_burst_params, smooth_age_weights, ssp_lg_age_gyr )