From 12f526ee1b5fc8ae9f6c787943e18b14a62cc346 Mon Sep 17 00:00:00 2001 From: Andrew Hearin Date: Tue, 10 Oct 2023 09:47:57 -0500 Subject: [PATCH 1/2] Implement cumulative_mstar_formed in calculation of distribution of stellar ages. Numerous tests now fail on account of NaNs --- dsps/sed/stellar_age_weights.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/dsps/sed/stellar_age_weights.py b/dsps/sed/stellar_age_weights.py index ea4607a..a0fcdee 100644 --- a/dsps/sed/stellar_age_weights.py +++ b/dsps/sed/stellar_age_weights.py @@ -1,9 +1,10 @@ """Kernels calculating stellar age PDF-weighting of SSP tempates""" -from jax import numpy as jnp from jax import jit as jjit -from ..utils import _jax_get_dt_array -from ..constants import SFR_MIN, T_BIRTH_MIN, N_T_LGSM_INTEGRATION +from jax import numpy as jnp + +from ..constants import N_T_LGSM_INTEGRATION, SFR_MIN, T_BIRTH_MIN from ..cosmology import TODAY +from ..utils import _jax_get_dt_array, cumulative_mstar_formed __all__ = ("calc_age_weights_from_sfh_table",) @@ -170,9 +171,7 @@ def _calc_logsm_table_from_sfh_table(gal_t_table, gal_sfr_table, sfr_min): Minimum star formation rate in Msun/yr """ - dt_table = _jax_get_dt_array(gal_t_table) - gal_sfr_table = jnp.where(gal_sfr_table < sfr_min, sfr_min, gal_sfr_table) - gal_mstar_table = jnp.cumsum(gal_sfr_table * dt_table) * 1e9 - logsm_table = jnp.log10(gal_mstar_table) + gal_smh_table = cumulative_mstar_formed(gal_t_table, gal_sfr_table) + logsm_table = jnp.log10(gal_smh_table) return logsm_table From f3001e68321546b3bb621f094f1b5ca771f66403 Mon Sep 17 00:00:00 2001 From: Andrew Hearin Date: Fri, 3 Nov 2023 15:37:58 -0500 Subject: [PATCH 2/2] Fix failing tests. Need to remember that gal_t_table should use T_TABLE_MIN and not T_BIRTH_MIN --- dsps/sed/tests/test_csp_sed.py | 12 +++++++----- dsps/sed/tests/test_ssp_weights.py | 12 +++++++----- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/dsps/sed/tests/test_csp_sed.py b/dsps/sed/tests/test_csp_sed.py index 14717e0..597af82 100644 --- a/dsps/sed/tests/test_csp_sed.py +++ b/dsps/sed/tests/test_csp_sed.py @@ -2,10 +2,12 @@ """ import numpy as np from jax import random as jran -from ..stellar_sed import calc_rest_sed_sfh_table_lognormal_mdf -from ..stellar_sed import calc_rest_sed_sfh_table_met_table -from ...constants import T_BIRTH_MIN +from ...constants import T_TABLE_MIN +from ..stellar_sed import ( + calc_rest_sed_sfh_table_lognormal_mdf, + calc_rest_sed_sfh_table_met_table, +) SEED = 43 FSPS_LG_AGES = np.arange(5.5, 10.2, 0.05) # log10 ages in years @@ -15,7 +17,7 @@ def test_calc_rest_sed_lognormal_mdf(): ran_key = jran.PRNGKey(SEED) t_obs = 13.0 n_t = 500 - gal_t_table = np.linspace(T_BIRTH_MIN, t_obs, n_t) + gal_t_table = np.linspace(T_TABLE_MIN, t_obs, n_t) sfr_key, met_key, sed_key = jran.split(ran_key, 3) gal_sfr_table = jran.uniform(sfr_key, minval=0, maxval=10, shape=(n_t,)) @@ -63,7 +65,7 @@ def test_calc_rest_sed_lgmet_table(): ran_key = jran.PRNGKey(SEED) t_obs = 13.0 n_t = 500 - gal_t_table = np.linspace(T_BIRTH_MIN, t_obs, n_t) + gal_t_table = np.linspace(T_TABLE_MIN, t_obs, n_t) sfr_key, met_key, sed_key = jran.split(ran_key, 3) gal_sfr_table = jran.uniform(sfr_key, minval=0, maxval=10, shape=(n_t,)) diff --git a/dsps/sed/tests/test_ssp_weights.py b/dsps/sed/tests/test_ssp_weights.py index 8e48f9e..d8f31ee 100644 --- a/dsps/sed/tests/test_ssp_weights.py +++ b/dsps/sed/tests/test_ssp_weights.py @@ -2,10 +2,12 @@ """ import numpy as np from jax import random as jran -from ..ssp_weights import calc_ssp_weights_sfh_table_lognormal_mdf -from ..ssp_weights import calc_ssp_weights_sfh_table_met_table -from ...constants import T_BIRTH_MIN +from ...constants import T_TABLE_MIN +from ..ssp_weights import ( + calc_ssp_weights_sfh_table_lognormal_mdf, + calc_ssp_weights_sfh_table_met_table, +) SEED = 43 FSPS_LG_AGES = np.arange(5.5, 10.2, 0.05) # log10 ages in years @@ -15,7 +17,7 @@ def test_calc_ssp_weights_lognormal_mdf(): ran_key = jran.PRNGKey(SEED) t_obs = 13.0 n_t = 500 - gal_t_table = np.linspace(T_BIRTH_MIN, t_obs, n_t) + gal_t_table = np.linspace(T_TABLE_MIN, t_obs, n_t) sfr_key, met_key = jran.split(ran_key, 2) gal_sfr_table = jran.uniform(sfr_key, minval=0, maxval=10, shape=(n_t,)) @@ -59,7 +61,7 @@ def test_calc_ssp_weights_met_table(): ran_key = jran.PRNGKey(SEED) t_obs = 13.0 n_t = 500 - gal_t_table = np.linspace(T_BIRTH_MIN, t_obs, n_t) + gal_t_table = np.linspace(T_TABLE_MIN, t_obs, n_t) sfr_key, met_key = jran.split(ran_key, 2) gal_sfr_table = jran.uniform(sfr_key, minval=0, maxval=10, shape=(n_t,))