diff --git a/dsps/sed/stellar_age_weights.py b/dsps/sed/stellar_age_weights.py index ea4607a..a0fcdee 100644 --- a/dsps/sed/stellar_age_weights.py +++ b/dsps/sed/stellar_age_weights.py @@ -1,9 +1,10 @@ """Kernels calculating stellar age PDF-weighting of SSP tempates""" -from jax import numpy as jnp from jax import jit as jjit -from ..utils import _jax_get_dt_array -from ..constants import SFR_MIN, T_BIRTH_MIN, N_T_LGSM_INTEGRATION +from jax import numpy as jnp + +from ..constants import N_T_LGSM_INTEGRATION, SFR_MIN, T_BIRTH_MIN from ..cosmology import TODAY +from ..utils import _jax_get_dt_array, cumulative_mstar_formed __all__ = ("calc_age_weights_from_sfh_table",) @@ -170,9 +171,7 @@ def _calc_logsm_table_from_sfh_table(gal_t_table, gal_sfr_table, sfr_min): Minimum star formation rate in Msun/yr """ - dt_table = _jax_get_dt_array(gal_t_table) - gal_sfr_table = jnp.where(gal_sfr_table < sfr_min, sfr_min, gal_sfr_table) - gal_mstar_table = jnp.cumsum(gal_sfr_table * dt_table) * 1e9 - logsm_table = jnp.log10(gal_mstar_table) + gal_smh_table = cumulative_mstar_formed(gal_t_table, gal_sfr_table) + logsm_table = jnp.log10(gal_smh_table) return logsm_table diff --git a/dsps/sed/tests/test_csp_sed.py b/dsps/sed/tests/test_csp_sed.py index 14717e0..597af82 100644 --- a/dsps/sed/tests/test_csp_sed.py +++ b/dsps/sed/tests/test_csp_sed.py @@ -2,10 +2,12 @@ """ import numpy as np from jax import random as jran -from ..stellar_sed import calc_rest_sed_sfh_table_lognormal_mdf -from ..stellar_sed import calc_rest_sed_sfh_table_met_table -from ...constants import T_BIRTH_MIN +from ...constants import T_TABLE_MIN +from ..stellar_sed import ( + calc_rest_sed_sfh_table_lognormal_mdf, + calc_rest_sed_sfh_table_met_table, +) SEED = 43 FSPS_LG_AGES = np.arange(5.5, 10.2, 0.05) # log10 ages in years @@ -15,7 +17,7 @@ def test_calc_rest_sed_lognormal_mdf(): ran_key = jran.PRNGKey(SEED) t_obs = 13.0 n_t = 500 - gal_t_table = np.linspace(T_BIRTH_MIN, t_obs, n_t) + gal_t_table = np.linspace(T_TABLE_MIN, t_obs, n_t) sfr_key, met_key, sed_key = jran.split(ran_key, 3) gal_sfr_table = jran.uniform(sfr_key, minval=0, maxval=10, shape=(n_t,)) @@ -63,7 +65,7 @@ def test_calc_rest_sed_lgmet_table(): ran_key = jran.PRNGKey(SEED) t_obs = 13.0 n_t = 500 - gal_t_table = np.linspace(T_BIRTH_MIN, t_obs, n_t) + gal_t_table = np.linspace(T_TABLE_MIN, t_obs, n_t) sfr_key, met_key, sed_key = jran.split(ran_key, 3) gal_sfr_table = jran.uniform(sfr_key, minval=0, maxval=10, shape=(n_t,)) diff --git a/dsps/sed/tests/test_ssp_weights.py b/dsps/sed/tests/test_ssp_weights.py index 8e48f9e..d8f31ee 100644 --- a/dsps/sed/tests/test_ssp_weights.py +++ b/dsps/sed/tests/test_ssp_weights.py @@ -2,10 +2,12 @@ """ import numpy as np from jax import random as jran -from ..ssp_weights import calc_ssp_weights_sfh_table_lognormal_mdf -from ..ssp_weights import calc_ssp_weights_sfh_table_met_table -from ...constants import T_BIRTH_MIN +from ...constants import T_TABLE_MIN +from ..ssp_weights import ( + calc_ssp_weights_sfh_table_lognormal_mdf, + calc_ssp_weights_sfh_table_met_table, +) SEED = 43 FSPS_LG_AGES = np.arange(5.5, 10.2, 0.05) # log10 ages in years @@ -15,7 +17,7 @@ def test_calc_ssp_weights_lognormal_mdf(): ran_key = jran.PRNGKey(SEED) t_obs = 13.0 n_t = 500 - gal_t_table = np.linspace(T_BIRTH_MIN, t_obs, n_t) + gal_t_table = np.linspace(T_TABLE_MIN, t_obs, n_t) sfr_key, met_key = jran.split(ran_key, 2) gal_sfr_table = jran.uniform(sfr_key, minval=0, maxval=10, shape=(n_t,)) @@ -59,7 +61,7 @@ def test_calc_ssp_weights_met_table(): ran_key = jran.PRNGKey(SEED) t_obs = 13.0 n_t = 500 - gal_t_table = np.linspace(T_BIRTH_MIN, t_obs, n_t) + gal_t_table = np.linspace(T_TABLE_MIN, t_obs, n_t) sfr_key, met_key = jran.split(ran_key, 2) gal_sfr_table = jran.uniform(sfr_key, minval=0, maxval=10, shape=(n_t,))